mirror of https://github.com/hpcaitech/ColossalAI
[elixir] refactored the chunk module (#3956)
parent 86ff5c152b
commit c173a69b3e
@@ -1,2 +1,7 @@
from .core import BlockRequire, Chunk, ChunkGroup, MemoryPool, PrivateBlock, PublicBlock, TensorBlock, TensorState
from .core import BlockSpec, Chunk, ChunkGroup, MemoryPool, PrivateBlock, PublicBlock, TensorBlock, TensorState
from .fetcher import ChunkFetcher

__all__ = [
    'BlockSpec', 'Chunk', 'ChunkGroup', 'MemoryPool', 'PrivateBlock', 'PublicBlock', 'TensorBlock', 'TensorState',
    'ChunkFetcher'
]

@@ -1,4 +1,8 @@
from .chunk import Chunk
from .group import ChunkGroup
from .memory_pool import BlockRequire, MemoryPool, PrivateBlock, PublicBlock, TensorBlock
from .memory_pool import BlockSpec, MemoryPool, PrivateBlock, PublicBlock, TensorBlock
from .states import TensorState

__all__ = [
    'Chunk', 'ChunkGroup', 'BlockSpec', 'MemoryPool', 'PrivateBlock', 'PublicBlock', 'TensorBlock', 'TensorState'
]
@@ -8,8 +8,8 @@ from torch.distributed import ProcessGroup
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.tensor import FakeTensor

from .memory_pool import MemoryPool, PrivateBlock, PublicBlock, TensorBlock
from .states import TensorState, ts_update_sanity_check
from .memory_pool import MemoryPool, TensorBlock
from .states import TensorState, validate_tensor_state_update


class ChunkFullError(Exception):

@@ -383,7 +383,11 @@ class Chunk:
        prev_state = self.tensors_info[tensor].state
        if prev_state == tensor_state:
            return
        if ts_update_sanity_check(prev_state, tensor_state):

        # validate whether the update is legal
        # if illegal, raise an exception
        is_update_valid = validate_tensor_state_update(prev_state, tensor_state, raise_exception=True)
        if is_update_valid:
            self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)

    def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
@@ -129,7 +129,7 @@ class ChunkGroup(object):
        """Check whether the rcache has enough blocks to store the gathered chunk."""
        if chunk.rcache_fused:
            return True
        return self.rcache.public_free_cnt > 0
        return self.rcache.public_free_count > 0

    def access_chunk(self, chunk: Chunk) -> bool:
        """Access a chunk into rCache."""

@@ -141,7 +141,7 @@ class ChunkGroup(object):
        if chunk.rcache_fused:
            block = None
        else:
            block = self.rcache.get_public_block()
            block = self.rcache.pop_public_block()
        chunk.access_chunk(block)
        self.__add_to_accset(chunk)
        return True
@@ -1,35 +1,52 @@
from abc import ABC
from collections import defaultdict
from enum import Enum
from typing import Iterable, NamedTuple

import torch
from torch.autograd.profiler_util import _format_memory


class BlockRequire(NamedTuple):
class BlockSpec(NamedTuple):
    """
    BlockSpec is the specification of a block. It contains the number of elements and the data type of the block.

    Args:
        numel (int): the number of elements in the block
        dtype (torch.dtype): the data type of the block
    """
    numel: int
    dtype: torch.dtype


class BlockType(Enum):
    """
    BlockType is the type of a block. There are two types of blocks: public and private.
    """
    PUBLIC = 0
    PRIVATE = 1


class TensorBlock(ABC):
    """TensorBlock is the memory unit of memory pool.
    It is a continuous memory block used to store tensors.
    """
    TensorBlock is the memory unit of memory pool. It is a contiguous memory block used to store tensors.
    Each chunk needs a corresponding TensorBlock to store its data during training.

    args:
        numel: the number of elements in the block
        dtype: the data type of the block
        device_type: the device type of the block
        size (int): the number of elements in the block
        dtype (torch.dtype): the data type of the block
        device_type (str): the device type of the block
    """
    total_count: int = 0

    def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
    def __init__(self, size: int, dtype: torch.dtype, device_type: str, block_type: BlockType) -> None:
        self.block_id = TensorBlock.total_count
        TensorBlock.total_count += 1

        self.device_type = device_type
        self.payload: torch.Tensor = torch.empty((numel,), dtype=dtype, device=device_type)
        self.memo_occ: int = self.payload.numel() * self.payload.element_size()
        self.payload: torch.Tensor = torch.empty((size,), dtype=dtype, device=device_type)
        self.size_in_bytes: int = self.payload.numel() * self.payload.element_size()
        self.block_type = block_type

    @property
    def numel(self):
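Editor's note: because `BlockSpec` stays a `NamedTuple` (like `BlockRequire` before it), equal `(numel, dtype)` pairs compare and hash by value, which is what lets the pool use it directly as the key of the private lookup dict further down in this diff. A minimal illustrative sketch (the `lookup` dict and the string payload are placeholders, not code from the commit):

from collections import defaultdict

import torch

from colossalai.elixir.chunk import BlockSpec

# equal (numel, dtype) specs hash to the same bucket, so they can index blocks
lookup = defaultdict(list)
lookup[BlockSpec(numel=1024, dtype=torch.float)].append('block-0')
assert lookup[BlockSpec(numel=1024, dtype=torch.float)] == ['block-0']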
@@ -50,122 +67,145 @@ class TensorBlock(ABC):
        return self.block_id == other.block_id

    def __repr__(self) -> str:
        return f'(id={self.block_id}, numel={self.numel}, device={self.device_type}, dtype={self.dtype}, memo={self.memo_occ})'
        return f'{self.block_type}(\n\tID = {self.block_id}, \n\tsize = {self.numel}, \n\tdevice = {self.device_type}, \n\tdtype = {self.dtype}, \n\tsize in bytes={self.size_in_bytes}\n)'


class PublicBlock(TensorBlock):
    """Public blocks have the same length.
    Chunks of the same length can share the same public block.
    """
    Public blocks have the same length. Chunks of the same length can share the same public block.
    """

    def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
        super().__init__(numel, dtype, device_type)
        self.block_type = 'public'

    def __repr__(self) -> str:
        return f'PublicBlock{super().__repr__()}'
        super().__init__(numel, dtype, device_type, BlockType.PUBLIC)


class PrivateBlock(TensorBlock):
    """Private blocks may have different lengths.
    Each private chunk should use its own private block.
    """
    Private blocks may have different lengths. Each private chunk should use its own private block.
    """

    def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
        super().__init__(numel, dtype, device_type)
        self.block_type = 'private'

    def __repr__(self) -> str:
        return f'PrivateBlock{super().__repr__()}'
        super().__init__(numel, dtype, device_type, BlockType.PRIVATE)


class MemoryPool(object):
    """A memory pool consists of public blocks and private blocks.
    """
    A memory pool consists of public blocks and private blocks.
    rCache uses the memory pool to manage memory blocks.
    Users should allocate memory blocks before using the pool.

    args:
        device_type: the device type of the memory pool
        device_type (str): the device type of the memory pool
    """

    def __init__(self, device_type: str) -> None:
        assert device_type in [
            'cuda', 'cpu'
        ], f'Expected device type to be cuda or cpu, but got an invalid device type: {device_type}'
        self.device_type: str = device_type

        # public space
        # public space = number of public blocks x the public block size in bytes
        # all public blocks have the same block size
        self.public_space: int = 0
        self.public_block_size: int = 0
        self.public_dtype: torch.dtype = None

        self.public_free_blocks: list = None
        self.public_used_blocks: set = None

        self.public_free_cnt: int = 0
        self.public_used_cnt: int = 0
        # create block holder and counter
        self.public_free_blocks: list = list()
        self.public_used_blocks: set = set()
        self.public_free_count: int = 0
        self.public_used_count: int = 0

        # private space
        # private lookup dict returns an empty list if the block is not found
        self.private_space: int = 0
        self.private_blocks: list = None
        self.private_lookup_dict: dict[BlockRequire, list] = None
        self.private_blocks: list = list()
        self.private_lookup_dict: dict[BlockSpec, list] = defaultdict(list)

        self.__allocate_flag = False
        # flags for block allocation
        self.__public_allocated_flag = False
        self.__private_allocated_flag = False

    def allocate(self,
                 public_dtype: torch.dtype = torch.float,
                 public_block_size: int = 1024,
                 public_block_number: int = 0,
                 private_block_list: Iterable[BlockRequire] = ()):
        assert self.__allocate_flag is False
        assert public_block_number >= 0
    def allocate_public_blocks(self, block_num: int, block_spec: BlockSpec = None):
        """
        Allocate public tensor blocks for the memory pool. This method will allocate block_num public blocks,
        each with the size and dtype given by block_spec (1024 float elements by default).
        """
        assert not self.__public_allocated_flag, 'Public blocks have been allocated to this MemoryPool object, it is not allowed to allocate again.'
        assert block_num >= 0, f'Expected block_num >= 0, but got {block_num}'

        self.public_free_blocks = list()
        self.public_used_blocks = set()
        for _ in range(public_block_number):
            block = PublicBlock(public_block_size, public_dtype, self.device_type)
        if block_spec is None:
            block_spec = BlockSpec(numel=1024, dtype=torch.float)

        # allocate public blocks
        for _ in range(block_num):
            block = PublicBlock(block_spec.numel, block_spec.dtype, self.device_type)
            self.public_free_blocks.append(block)
            self.public_space += block.size_in_bytes
            self.public_free_count += 1

        if public_block_number <= 0:
            self.public_space = 0
        else:
            self.public_space = self.public_free_blocks[0].memo_occ * public_block_number
        self.public_block_size = public_block_size
        self.public_dtype = public_dtype
        # store the block spec info
        self.public_block_size = block_spec.numel
        self.public_dtype = block_spec.dtype

        self.public_free_cnt = public_block_number
        self.public_used_cnt = 0
    def allocate_private_blocks(self, block_specs: Iterable[BlockSpec]):
        """
        Allocate private blocks for the memory pool. This method will allocate private blocks according to the block_specs.

        self.private_space = 0
        self.private_blocks = list()
        self.private_lookup_dict = defaultdict(list)
        Args:
            block_specs (Iterable[BlockSpec]): the block specs of the private blocks to be allocated
        """
        # allocate private blocks
        assert not self.__private_allocated_flag, 'Private blocks have been allocated to this MemoryPool object, it is not allowed to allocate again.'

        for require in private_block_list:
            block = PrivateBlock(require.numel, require.dtype, self.device_type)
            self.private_space += block.memo_occ
        for spec in block_specs:
            block = PrivateBlock(spec.numel, spec.dtype, self.device_type)
            self.private_space += block.size_in_bytes
            self.private_blocks.append(block)
            self.private_lookup_dict[require].append(block)
            self.private_lookup_dict[spec].append(block)

        self.__allocate_flag = True
        self.__private_allocated_flag = True

    def __repr__(self) -> str:
        return f'MP(public_space={_format_memory(self.public_space)}, private_space={_format_memory(self.private_space)})'
        return f'Memory Pool(\n\tpublic_space = {_format_memory(self.public_space)}, \n\tprivate_space={_format_memory(self.private_space)}\n)'

    def get_private_block(self, numel: int, dtype: torch.dtype):
        block_list = self.private_lookup_dict.get(BlockRequire(numel=numel, dtype=dtype))
        return block_list.pop()
    def get_private_block(self, numel: int, dtype: torch.dtype) -> PrivateBlock:
        """
        Get a private block with the given numel and dtype.
        """
        block_list = self.private_lookup_dict.get(BlockSpec(numel=numel, dtype=dtype))

    def get_public_block(self):
        self.public_free_cnt -= 1
        self.public_used_cnt += 1
        if len(block_list) == 0:
            raise ValueError(f'No private block with numel={numel} and dtype={dtype} is found.')
        else:
            return block_list.pop()

    def pop_public_block(self) -> PublicBlock:
        """
        Get a public block from the memory pool.
        """
        self.public_free_count -= 1
        self.public_used_count += 1

        block = self.public_free_blocks.pop()
        self.public_used_blocks.add(block)

        return block

    def free_public_block(self, block: TensorBlock):
    def free_public_block(self, block: TensorBlock) -> PublicBlock:
        """
        Free a public block to the memory pool.

        Args:
            block (TensorBlock): the public block to be freed
        """
        assert isinstance(block, PublicBlock)
        assert block in self.public_used_blocks
        assert block in self.public_used_blocks, f'Could not find the given block in the used public blocks'

        # update counter
        self.public_free_count += 1
        self.public_used_count -= 1

        # update free and used blocks
        self.public_used_blocks.remove(block)
        self.public_free_blocks.append(block)
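Editor's note: the refactor splits the old all-in-one `allocate()` into `allocate_public_blocks` and `allocate_private_blocks`, and renames the counters and accessors (`public_free_cnt` -> `public_free_count`, `get_public_block` -> `pop_public_block`, and so on). The sketch below is distilled from the updated `test_memory_pool` later in this diff; using `'cpu'` here is an assumption made only so the snippet runs without a GPU, the tests themselves use `'cuda'`:

import torch

from colossalai.elixir.chunk import BlockSpec, MemoryPool

mp = MemoryPool(device_type='cpu')    # 'cpu' assumed here; the tests use 'cuda'

# public blocks share one spec and are allocated up front
mp.allocate_public_blocks(block_num=4, block_spec=BlockSpec(numel=1024, dtype=torch.float))

# private blocks get one spec per fused chunk
mp.allocate_private_blocks([BlockSpec(5, torch.float), BlockSpec(81, torch.float16)])

# public blocks are borrowed and returned explicitly
block = mp.pop_public_block()
assert mp.public_used_count == 1 and mp.public_free_count == 3
mp.free_public_block(block)

# private blocks are looked up by their exact (numel, dtype) spec
private = mp.get_private_block(5, torch.float)
assert private.numel == 5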
@@ -2,6 +2,10 @@ from enum import Enum


class TensorState(Enum):
    """
    TensorState represents the state of a tensor in Elixir.
    There are five states of a tensor: free, compute, hold, hold_after_bwd, ready_for_reduce.
    """
    FREE = 0
    COMPUTE = 1
    HOLD = 2
@@ -9,17 +13,35 @@ class TensorState(Enum):
    READY_FOR_REDUCE = 4


# expected: free -> hold -> compute -> hold ->
# this lists the possible state transitions of a tensor:
# each item in the list is in the format of (old_state, new_state)
# the complete state transition is:
# free -> hold -> compute -> hold ->
#      -> compute -> hold_after_bwd -> ready_for_reduce
legal_ts_update_list = [(TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
                        (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
                        (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
                        (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
                        (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
                        (TensorState.READY_FOR_REDUCE, TensorState.HOLD)]
LEGAL_TENSOR_STATE_UPDATE_LIST = [(TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
                                  (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
                                  (TensorState.COMPUTE, TensorState.HOLD),
                                  (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
                                  (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
                                  (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
                                  (TensorState.READY_FOR_REDUCE, TensorState.HOLD)]


def ts_update_sanity_check(old_state, new_state) -> bool:
    if (old_state, new_state) not in legal_ts_update_list:
        raise RuntimeError(f'illegal tensor state updating: {old_state} -> {new_state}')
def validate_tensor_state_update(old_state: TensorState, new_state: TensorState, raise_exception: bool = False) -> bool:
    """
    Validate whether the tensor state update is legal.

    Args:
        old_state (TensorState): the old state of the tensor
        new_state (TensorState): the new state of the tensor
        raise_exception (bool, optional): whether to raise an exception when the state update is illegal. Defaults to False.

    Returns:
        bool: whether the state update is legal or not.
    """
    if (old_state, new_state) not in LEGAL_TENSOR_STATE_UPDATE_LIST:
        if raise_exception:
            raise RuntimeError(f'Found illegal tensor state updating: {old_state} -> {new_state}')
        else:
            return False
    return True
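Editor's note: the renamed validator can be called in a raising or non-raising mode; a small sketch of both styles, assuming the module path `colossalai.elixir.chunk.core.states` implied by the relative imports above:

from colossalai.elixir.chunk.core.states import TensorState, validate_tensor_state_update

# legal transition: returns True
assert validate_tensor_state_update(TensorState.FREE, TensorState.HOLD)

# illegal transition: returns False with the default raise_exception=False
assert not validate_tensor_state_update(TensorState.FREE, TensorState.READY_FOR_REDUCE)

# illegal transition with raise_exception=True raises RuntimeError instead
try:
    validate_tensor_state_update(TensorState.FREE, TensorState.READY_FOR_REDUCE, raise_exception=True)
except RuntimeError as err:
    print(err)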
@@ -0,0 +1,3 @@
from .meta_context import MetaContext

__all__ = ['MetaContext']
@@ -1,6 +1,6 @@
import torch

tensor_creation_methods = dict(tensor=torch.tensor,
TESNOR_CREATION_METHODS = dict(tensor=torch.tensor,
                               sparse_coo_tensor=torch.sparse_coo_tensor,
                               asarray=torch.asarray,
                               as_tensor=torch.as_tensor,

@@ -29,4 +29,34 @@ tensor_creation_methods = dict(tensor=torch.tensor,
                               polar=torch.polar,
                               heaviside=torch.heaviside)

from .meta_ctx import MetaContext


# TODO: unify this with lazy init context
class MetaContext(object):
    """A context manager that wraps all tensor creation methods in torch.
    By default, all tensors will be created on the meta device.

    args:
        device_type: The device type of the tensors to be created.
    """

    def __init__(self, device_type: str = 'meta') -> None:
        super().__init__()
        self.device_type = device_type
        return None

    def __enter__(self):

        def meta_wrap(func):

            def wrapped_func(*args, **kwargs):
                kwargs['device'] = self.device_type
                return func(*args, **kwargs)

            return wrapped_func

        for name, method in TESNOR_CREATION_METHODS.items():
            setattr(torch, name, meta_wrap(method))

    def __exit__(self, exc_type, exc_val, exc_tb):
        for name, method in TESNOR_CREATION_METHODS.items():
            setattr(torch, name, method)
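Editor's note: `MetaContext` temporarily patches the torch factory functions listed in the (only partially shown) creation-method table so that new tensors land on the meta device; it is assumed here that the full table also covers `torch.empty`, which the resnet test model relies on. A hedged usage sketch:

import torch

from colossalai.elixir.context import MetaContext

with MetaContext():
    # factories are patched inside the context, so no real memory is allocated
    weight = torch.empty(1024, 1024)
assert weight.is_meta

# on exit the original factories are restored
real = torch.empty(4)
assert not real.is_meta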
@@ -1,34 +0,0 @@
import torch

from colossalai.elixir.ctx import tensor_creation_methods


class MetaContext(object):
    """A context manager that wraps all tensor creation methods in torch.
    By default, all tensors will be created in meta.

    args:
        device_type: The device type of the tensors to be created.
    """

    def __init__(self, device_type: str = 'meta') -> None:
        super().__init__()
        self.device_type = device_type
        return None

    def __enter__(self):

        def meta_wrap(func):

            def wrapped_func(*args, **kwargs):
                kwargs['device'] = self.device_type
                return func(*args, **kwargs)

            return wrapped_func

        for name, method in tensor_creation_methods.items():
            setattr(torch, name, meta_wrap(method))

    def __exit__(self, exc_type, exc_val, exc_tb):
        for name, method in tensor_creation_methods.items():
            setattr(torch, name, method)
@@ -1,4 +1,3 @@
import torch
import torch.nn.functional as F

fused_torch_functions = {F.layer_norm: F.layer_norm}

@@ -12,6 +11,3 @@ def register_fused_layer_norm():
    except:
        print('Cannot import fused layer norm, please install apex from source.')
        pass


register_fused_layer_norm()
@@ -5,7 +5,7 @@ from typing import List, Tuple
import torch
import torch.nn as nn

from colossalai.elixir.chunk import BlockRequire, ChunkGroup, MemoryPool
from colossalai.elixir.chunk import BlockSpec, ChunkGroup, MemoryPool
from colossalai.elixir.tracer.param_tracer import generate_tf_order
from colossalai.elixir.tracer.utils import meta_copy
from colossalai.elixir.utils import print_rank_0

@@ -119,7 +119,7 @@ class SearchBase(ABC):
        for plan in chunk_plans:
            kwargs = plan.kwargs
            if kwargs.get('rcache_fused', False):
                block_require_list.append(BlockRequire(plan.chunk_size, plan.chunk_dtype))
                block_require_list.append(BlockSpec(plan.chunk_size, plan.chunk_dtype))

        mp = MemoryPool('cuda')
        mp.allocate(public_dtype=self.unified_dtype,
@@ -4,7 +4,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.elixir.chunk import BlockRequire, ChunkFetcher, ChunkGroup, MemoryPool, TensorState
from colossalai.elixir.chunk import BlockSpec, ChunkFetcher, ChunkGroup, MemoryPool, TensorState
from colossalai.elixir.chunk.scheduler import FIFOScheduler
from colossalai.elixir.hook import BufferStore, HookParam
from colossalai.elixir.tensor import OutplaceTensor

@@ -32,13 +32,15 @@ def grad_handler(grad: torch.Tensor, param: nn.Parameter, fetcher: ChunkFetcher)
def hook_transform(model: nn.Module, process_group: dist.ProcessGroupGloo):
    pg_size = dist.get_world_size(process_group)

    private_list = list()
    mp = MemoryPool('cuda')

    # allocate private blocks
    private_block_specs = list()
    for param in model.parameters():
        block_size = to_divide(param.numel(), pg_size)
        private_list.append(BlockRequire(block_size, param.dtype))
        private_block_specs.append(BlockSpec(block_size, param.dtype))
    mp.allocate_private_blocks(private_block_specs)

    mp = MemoryPool('cuda')
    mp.allocate(private_block_list=private_list)
    cg = ChunkGroup(rcache=mp)
    # allocate chunk group
    fused_config = dict(rcache_fused=True)
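Editor's note: each parameter gets a private block whose size is padded by `to_divide` so the block can be sharded evenly across the process group; the helper below is a hypothetical stand-in for that rounding (its real definition is not part of this hunk):

def to_divide(numel: int, divisor: int) -> int:
    # assumed behaviour: round numel up to the next multiple of the group size
    return ((numel + divisor - 1) // divisor) * divisor


# e.g. a 10-element parameter on a 4-rank group gets a 12-element private block
assert to_divide(10, 4) == 12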
@@ -1,60 +1,62 @@
import torch

from colossalai.elixir.chunk import BlockRequire, MemoryPool, PrivateBlock, PublicBlock
from colossalai.elixir.chunk import BlockSpec, MemoryPool, PrivateBlock, PublicBlock
from colossalai.testing import run_on_environment_flag


@run_on_environment_flag('ELX')
def test_block():
    b = PublicBlock(123, torch.float16, 'cuda')
    payload_b = b.payload
    # test for public block
    public_block = PublicBlock(123, torch.float16, 'cuda')
    public_payload = public_block.payload

    assert payload_b.numel() == 123
    assert payload_b.dtype == torch.float16
    assert payload_b.device.type == 'cuda'
    assert payload_b.numel() * payload_b.element_size() == b.memo_occ
    assert public_payload.numel() == 123
    assert public_payload.dtype == torch.float16
    assert public_payload.device.type == 'cuda'
    assert public_payload.numel() * public_payload.element_size() == public_block.size_in_bytes

    c = PrivateBlock(77, torch.float, 'cpu')
    payload_c = c.payload

    assert payload_c.numel() == 77
    assert payload_c.dtype == torch.float
    assert payload_c.device.type == 'cpu'
    assert payload_c.numel() * payload_c.element_size() == c.memo_occ
    # test for private block
    private_block = PrivateBlock(77, torch.float, 'cpu')
    private_payload = private_block.payload

    assert private_payload.numel() == 77
    assert private_payload.dtype == torch.float
    assert private_payload.device.type == 'cpu'
    assert private_payload.numel() * private_payload.element_size() == private_block.size_in_bytes
    print('test_block: ok')


@run_on_environment_flag('ELX')
def test_memory_pool():
    mp = MemoryPool(device_type='cuda')
    private_list = [BlockRequire(5, torch.float), BlockRequire(81, torch.float16)]
    mp.allocate(public_block_number=4, private_block_list=private_list)

    block0 = mp.get_public_block()
    # allocate public blocks
    mp.allocate_public_blocks(block_num=4)

    # allocate private blocks
    private_block_specs = [BlockSpec(5, torch.float), BlockSpec(81, torch.float16)]
    mp.allocate_private_blocks(private_block_specs)

    # test for public blocks
    block0 = mp.pop_public_block()
    assert block0 in mp.public_used_blocks
    assert mp.public_used_cnt == 1
    assert mp.public_free_cnt == 3

    block1 = mp.get_public_block()
    assert mp.public_used_count == 1
    assert mp.public_free_count == 3

    block1 = mp.pop_public_block()
    assert block1 in mp.public_used_blocks
    assert mp.public_used_cnt == 2
    assert mp.public_free_cnt == 2
    assert mp.public_used_count == 2
    assert mp.public_free_count == 2

    mp.free_public_block(block0)
    mp.free_public_block(block1)

    assert block0 in mp.public_free_blocks
    assert block1 in mp.public_free_blocks
    assert mp.public_used_cnt == 0
    assert mp.public_free_cnt == 4
    assert mp.public_used_count == 0
    assert mp.public_free_count == 4

    # test for private block
    block0 = mp.get_private_block(5, torch.float)
    assert block0.numel == 5
    assert block0.dtype == torch.float

    print('test_memory_pool: ok')
@@ -1,13 +1,10 @@
import os
from functools import partial

import pytest
import torch
import torch.distributed as dist

from colossalai.elixir.chunk import BlockRequire, Chunk, MemoryPool, TensorState
from colossalai.elixir.utils import init_distributed
from colossalai.testing import run_on_environment_flag
import colossalai
from colossalai.elixir.chunk import BlockSpec, Chunk, MemoryPool, TensorState
from colossalai.testing import run_on_environment_flag, spawn


def exam_chunk_functions(nproc, group):

@@ -21,7 +18,7 @@ def exam_chunk_functions(nproc, group):
    copy_d = d.clone()

    mp = MemoryPool('cuda')
    mp.allocate(public_block_number=1)
    mp.allocate_public_blocks(block_num=1)

    chunk = Chunk(mp, 1024, torch.float, group)
    chunk.l2_norm_flag = True

@@ -43,26 +40,31 @@ def exam_chunk_functions(nproc, group):

    chunk.close_chunk()
    assert chunk.is_replica is False

    # check function: get_cpu_copy
    cpu_copys = chunk.get_cpu_copy()
    for t_gpu, t_cpu in zip([copy_a, copy_b, copy_c, copy_d], cpu_copys):
        assert t_cpu.device.type == 'cpu'
        assert torch.equal(t_gpu.cpu(), t_cpu)

    # check function: access_chunk
    block = mp.get_public_block()
    block = mp.pop_public_block()
    chunk.access_chunk(block)
    assert chunk.is_replica
    assert chunk.scatter_check
    check_tensors()

    # check function: release_chunk
    chunk.optim_sync_flag = False
    block = chunk.release_chunk()
    assert block in mp.public_used_blocks
    assert chunk.is_replica is False
    assert chunk.optim_sync_flag is True

    # check function: access_chunk after release_chunk
    chunk.access_chunk(block)
    check_tensors()

    # check function: reduce_chunk
    norm = block.payload.float().norm(2)**2
    chunk.reduce_chunk()

@@ -87,9 +89,10 @@ def exam_chunk_states(nproc, group):
    d = torch.randn(4, 32, device='cuda')
    copy_d = d.clone()

    private = [BlockRequire(1024, torch.float)]
    mp = MemoryPool('cuda')
    mp.allocate(private_block_list=private)

    private_block_specs = [BlockSpec(1024, torch.float)]
    mp.allocate_private_blocks(private_block_specs)

    chunk = Chunk(mp, 1024, torch.float, group, rcache_fused=True)
    assert chunk.chunk_size == 1024

@@ -132,23 +135,16 @@ def exam_chunk_states(nproc, group):
    print('chunk states are ok')


def run_dist(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(29512)
    init_distributed()
def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    exam_chunk_functions(nproc=world_size, group=dist.GroupMember.WORLD)
    exam_chunk_states(nproc=world_size, group=dist.GroupMember.WORLD)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_chunk_functions(world_size):
    run_func = partial(run_dist, world_size=world_size)
    torch.multiprocessing.spawn(run_func, nprocs=world_size)
    spawn(run_dist, nprocs=world_size)


if __name__ == '__main__':
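Editor's note: the same launcher migration is applied to every test file in this commit: the hand-rolled environment-variable setup plus `torch.multiprocessing.spawn` gives way to `colossalai.launch` inside `run_dist` and the `colossalai.testing.spawn` helper. Condensed into one hypothetical test (`test_something` is an illustrative name, and the per-test `exam_*` calls are elided):

import pytest

import colossalai
from colossalai.testing import spawn


def run_dist(rank, world_size, port):
    # spawn supplies rank, world_size and a port to each worker process
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    # ... the per-test exam_* functions would run here ...


@pytest.mark.dist
def test_something(world_size=2):
    spawn(run_dist, nprocs=world_size)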
@@ -1,6 +1,4 @@
import copy
import os
from functools import partial

import pytest
import torch

@@ -8,9 +6,10 @@ import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close

import colossalai
from colossalai.elixir.chunk import ChunkGroup
from colossalai.elixir.utils import init_distributed, seed_all
from colossalai.testing import run_on_environment_flag
from colossalai.elixir.utils import seed_all
from colossalai.testing import run_on_environment_flag, spawn
from tests.test_elixir.test_chunk.fetcher_utils import hook_transform
from tests.test_elixir.utils import TEST_MODELS, to_cuda

@@ -25,7 +24,7 @@ def check_gradient(ddp_model, my_model, cg: ChunkGroup):
        assert_close(p0.grad.data, p1.data)


def exam_chunk_fetcher(nproc, group):
def exam_chunk_fetcher(group):
    model_fn, data_fn = TEST_MODELS.get('resnet')
    torch_model = model_fn().cuda()
    test_model = copy.deepcopy(torch_model)

@@ -49,23 +48,17 @@ def exam_chunk_fetcher(nproc, group):
    print('private chunk fetcher is ok')


def run_dist(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(29512)
    init_distributed()
    exam_chunk_fetcher(nproc=world_size, group=dist.GroupMember.WORLD)
def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    exam_chunk_fetcher(group=dist.GroupMember.WORLD)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_chunk_fetcher(world_size):
    run_func = partial(run_dist, world_size=world_size)
    torch.multiprocessing.spawn(run_func, nprocs=world_size)
    spawn(run_dist, nprocs=world_size)


if __name__ == '__main__':
    test_chunk_fetcher(world_size=2)
    test_chunk_fetcher(world_size=2)
@@ -1,16 +1,13 @@
import os
from functools import partial

import pytest
import torch
import torch.distributed as dist

from colossalai.elixir.chunk import BlockRequire, ChunkGroup, MemoryPool, TensorState
from colossalai.elixir.utils import init_distributed
from colossalai.testing import run_on_environment_flag
import colossalai
from colossalai.elixir.chunk import BlockSpec, ChunkGroup, MemoryPool, TensorState
from colossalai.testing import run_on_environment_flag, spawn


def exam_chunk_group_functions(nproc, group):
def exam_chunk_group_functions(group):
    a = torch.randn(3, 64, device='cuda')
    copy_a = a.clone()
    b = torch.randn(2, 32, device='cuda')

@@ -23,7 +20,9 @@ def exam_chunk_group_functions(nproc, group):
    copy_e = e.clone()

    mp = MemoryPool('cuda')
    mp.allocate(public_block_size=256, public_block_number=2, private_block_list=[BlockRequire(68, torch.float)])
    mp.allocate_public_blocks(block_num=2, block_spec=BlockSpec(numel=256, dtype=torch.float))
    mp.allocate_private_blocks([BlockSpec(68, torch.float)])

    cg = ChunkGroup(rcache=mp)
    c0 = cg.allocate_chunk([a, b], 256, torch.float, group)
    c1 = cg.allocate_chunk([c], 256, torch.float, group)

@@ -76,22 +75,15 @@ def exam_chunk_group_functions(nproc, group):
    print('chunk group functions are ok')


def run_dist(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(29512)
    init_distributed()
    exam_chunk_group_functions(nproc=world_size, group=dist.GroupMember.WORLD)
def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    exam_chunk_group_functions(group=dist.GroupMember.WORLD)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_chunk_group(world_size):
    run_func = partial(run_dist, world_size=world_size)
    torch.multiprocessing.spawn(run_func, nprocs=world_size)
    spawn(run_dist, nprocs=world_size)


if __name__ == '__main__':
@@ -1,19 +1,16 @@
import os
from functools import partial

import pytest
import torch
import torch.distributed as dist

import colossalai
from colossalai.elixir.chunk import Chunk, MemoryPool
from colossalai.elixir.chunk.scheduler import FIFOScheduler, PrefetchScheduler
from colossalai.elixir.utils import init_distributed
from colossalai.testing import run_on_environment_flag
from colossalai.testing import spawn


def exam_fifo(nproc, group):
def exam_fifo(group):
    mp = MemoryPool('cuda')
    mp.allocate(public_block_number=1)
    mp.allocate_public_blocks(block_num=1)
    c0 = Chunk(mp, 1024, torch.float, group)
    c1 = Chunk(mp, 1024, torch.float, group)
    c2 = Chunk(mp, 1024, torch.float, group)

@@ -40,9 +37,8 @@ def exam_fifo(nproc, group):
    assert sdl.top() == c0


def exam_prefetch(nproc, group):
def exam_prefetch(group):
    mp = MemoryPool('cuda')
    mp.allocate()
    c0 = Chunk(mp, 1024, torch.float, group)
    c1 = Chunk(mp, 1024, torch.float, group)
    c2 = Chunk(mp, 1024, torch.float, group)

@@ -108,22 +104,15 @@ def exam_prefetch(nproc, group):
    sdl.clear()


def run_dist(rank, world_size):
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(rank)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = str(29512)
    init_distributed()
    exam_fifo(nproc=world_size, group=dist.GroupMember.WORLD)
    exam_prefetch(nproc=world_size, group=dist.GroupMember.WORLD)
def run_dist(rank, world_size, port):
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    exam_fifo(group=dist.GroupMember.WORLD)
    exam_prefetch(group=dist.GroupMember.WORLD)


@pytest.mark.dist
@run_on_environment_flag('ELX')
def test_chunk_scheduler(world_size=1):
    run_func = partial(run_dist, world_size=world_size)
    torch.multiprocessing.spawn(run_func, nprocs=world_size)
    spawn(run_dist, nprocs=world_size)


if __name__ == '__main__':
@@ -1,9 +1,7 @@
from colossalai.elixir.ctx import MetaContext
from colossalai.testing import run_on_environment_flag
from colossalai.elixir.context import MetaContext
from tests.test_elixir.utils import TEST_MODELS


@run_on_environment_flag('ELX')
def test_meta_context():
    builder, *_ = TEST_MODELS.get('resnet')
    with MetaContext():