
[elixir] refactored the chunk module (#3956)

Branch: feature/elixir
Frank Lee, 1 year ago, committed by GitHub
commit c173a69b3e
18 changed files:

 1. colossalai/elixir/chunk/__init__.py (7)
 2. colossalai/elixir/chunk/core/__init__.py (6)
 3. colossalai/elixir/chunk/core/chunk.py (10)
 4. colossalai/elixir/chunk/core/group.py (4)
 5. colossalai/elixir/chunk/core/memory_pool.py (200)
 6. colossalai/elixir/chunk/core/states.py (42)
 7. colossalai/elixir/context/__init__.py (3)
 8. colossalai/elixir/context/meta_context.py (34)
 9. colossalai/elixir/ctx/meta_ctx.py (34)
10. colossalai/elixir/kernels/__init__.py (4)
11. colossalai/elixir/search/base.py (4)
12. tests/test_elixir/test_chunk/fetcher_utils.py (12)
13. tests/test_elixir/test_chunk/test_block.py (66)
14. tests/test_elixir/test_chunk/test_chunk.py (36)
15. tests/test_elixir/test_chunk/test_fetcher.py (25)
16. tests/test_elixir/test_chunk/test_group.py (30)
17. tests/test_elixir/test_chunk/test_scheduler.py (31)
18. tests/test_elixir/test_ctx/test_meta_ctx.py (4)

colossalai/elixir/chunk/__init__.py (7)

@@ -1,2 +1,7 @@
-from .core import BlockRequire, Chunk, ChunkGroup, MemoryPool, PrivateBlock, PublicBlock, TensorBlock, TensorState
+from .core import BlockSpec, Chunk, ChunkGroup, MemoryPool, PrivateBlock, PublicBlock, TensorBlock, TensorState
 from .fetcher import ChunkFetcher
+
+__all__ = [
+    'BlockSpec', 'Chunk', 'ChunkGroup', 'MemoryPool', 'PrivateBlock', 'PublicBlock', 'TensorBlock', 'TensorState',
+    'ChunkFetcher'
+]

colossalai/elixir/chunk/core/__init__.py (6)

@@ -1,4 +1,8 @@
 from .chunk import Chunk
 from .group import ChunkGroup
-from .memory_pool import BlockRequire, MemoryPool, PrivateBlock, PublicBlock, TensorBlock
+from .memory_pool import BlockSpec, MemoryPool, PrivateBlock, PublicBlock, TensorBlock
 from .states import TensorState
+
+__all__ = [
+    'Chunk', 'ChunkGroup', 'BlockSpec', 'MemoryPool', 'PrivateBlock', 'PublicBlock', 'TensorBlock', 'TensorState'
+]

colossalai/elixir/chunk/core/chunk.py (10)

@@ -8,8 +8,8 @@ from torch.distributed import ProcessGroup
 from colossalai.elixir.cuda import gpu_device
 from colossalai.elixir.tensor import FakeTensor

-from .memory_pool import MemoryPool, PrivateBlock, PublicBlock, TensorBlock
-from .states import TensorState, ts_update_sanity_check
+from .memory_pool import MemoryPool, TensorBlock
+from .states import TensorState, validate_tensor_state_update


 class ChunkFullError(Exception):
@@ -383,7 +383,11 @@ class Chunk:
         prev_state = self.tensors_info[tensor].state
         if prev_state == tensor_state:
             return
-        if ts_update_sanity_check(prev_state, tensor_state):
+        # validate whether the update is legal
+        # if illegal, raise an exception
+        is_update_valid = validate_tensor_state_update(prev_state, tensor_state, raise_exception=True)
+        if is_update_valid:
             self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)

     def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:

colossalai/elixir/chunk/core/group.py (4)

@@ -129,7 +129,7 @@ class ChunkGroup(object):
         """Check whether the rcache has enough blocks to store the gathered chunk."""
         if chunk.rcache_fused:
             return True
-        return self.rcache.public_free_cnt > 0
+        return self.rcache.public_free_count > 0

     def access_chunk(self, chunk: Chunk) -> bool:
         """Access a chunk into rCache."""
@@ -141,7 +141,7 @@ class ChunkGroup(object):
         if chunk.rcache_fused:
             block = None
         else:
-            block = self.rcache.get_public_block()
+            block = self.rcache.pop_public_block()
         chunk.access_chunk(block)
         self.__add_to_accset(chunk)
         return True

colossalai/elixir/chunk/core/memory_pool.py (200)

@@ -1,35 +1,52 @@
 from abc import ABC
 from collections import defaultdict
+from enum import Enum
 from typing import Iterable, NamedTuple

 import torch
 from torch.autograd.profiler_util import _format_memory


-class BlockRequire(NamedTuple):
+class BlockSpec(NamedTuple):
+    """
+    BlockSpec is the specification of a block. It contains the number of elements and the data type of the block.
+
+    Args:
+        numel (int): the number of elements in the block
+        dtype (torch.dtype): the data type of the block
+    """
     numel: int
     dtype: torch.dtype


+class BlockType(Enum):
+    """
+    BlockType is the type of a block. There are two types of blocks: public and private.
+    """
+    PUBLIC = 0
+    PRIVATE = 1
+
+
 class TensorBlock(ABC):
-    """TensorBlock is the memory unit of memory pool.
-    It is a continuous memory block used to store tensors.
+    """
+    TensorBlock is the memory unit of memory pool. It is a contiguous memory block used to store tensors.
     Each chunk needs a corresponding TensorBlock to store its data during training.

     args:
-        numel: the number of elements in the block
-        dtype: the data type of the block
-        device_type: the device type of the block
+        size (int): the number of elements in the block
+        dtype (torch.dtype): the data type of the block
+        device_type (str): the device type of the block
     """
     total_count: int = 0

-    def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
+    def __init__(self, size: int, dtype: torch.dtype, device_type: str, block_type: BlockType) -> None:
         self.block_id = TensorBlock.total_count
         TensorBlock.total_count += 1

         self.device_type = device_type
-        self.payload: torch.Tensor = torch.empty((numel,), dtype=dtype, device=device_type)
-        self.memo_occ: int = self.payload.numel() * self.payload.element_size()
+        self.payload: torch.Tensor = torch.empty((size,), dtype=dtype, device=device_type)
+        self.size_in_bytes: int = self.payload.numel() * self.payload.element_size()
+        self.block_type = block_type

     @property
     def numel(self):
@@ -50,122 +67,145 @@ class TensorBlock(ABC):
         return self.block_id == other.block_id

     def __repr__(self) -> str:
-        return f'(id={self.block_id}, numel={self.numel}, device={self.device_type}, dtype={self.dtype}, memo={self.memo_occ})'
+        return f'{self.block_type}(\n\tID = {self.block_id}, \n\tsize = {self.numel}, \n\tdevice = {self.device_type}, \n\tdtype = {self.dtype}, \n\tsize in bytes={self.size_in_bytes}\n)'


 class PublicBlock(TensorBlock):
-    """Public blocks have the same length.
-    Chunks of the same length can share the same public block.
+    """
+    Public blocks have the same length. Chunks of the same length can share the same public block.
     """

     def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
-        super().__init__(numel, dtype, device_type)
-        self.block_type = 'public'
-
-    def __repr__(self) -> str:
-        return f'PublicBlock{super().__repr__()}'
+        super().__init__(numel, dtype, device_type, BlockType.PUBLIC)


 class PrivateBlock(TensorBlock):
-    """Private blocks may have different lengths.
-    Each private chunk should use its own private block.
+    """
+    Private blocks may have different lengths. Each private chunk should use its own private block.
    """

     def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
-        super().__init__(numel, dtype, device_type)
-        self.block_type = 'private'
-
-    def __repr__(self) -> str:
-        return f'PrivateBlock{super().__repr__()}'
+        super().__init__(numel, dtype, device_type, BlockType.PRIVATE)


 class MemoryPool(object):
-    """A memory pool consists of public blocks and private blocks.
+    """
+    A memory pool consists of public blocks and private blocks.
     rCache uses memory pool to manage memory bolcks.
     Users should allocate memory blocks before using it.

     args:
-        device_type: the device type of the memory pool
+        device_type (str): the device type of the memory pool
     """

     def __init__(self, device_type: str) -> None:
+        assert device_type in [
+            'cuda', 'cpu'
+        ], f'Expected device type to be cuda or cpu, but got an invalid device type: {device_type}'
         self.device_type: str = device_type

+        # public space
+        # public space = number of public block x the public block size in bytes
+        # all public blocks have the same block size
         self.public_space: int = 0
         self.public_block_size: int = 0
         self.public_dtype: torch.dtype = None

-        self.public_free_blocks: list = None
-        self.public_used_blocks: set = None
-
-        self.public_free_cnt: int = 0
-        self.public_used_cnt: int = 0
+        # create block holder and counter
+        self.public_free_blocks: list = list()
+        self.public_used_blocks: set = set()
+        self.public_free_count: int = 0
+        self.public_used_count: int = 0

+        # private space
+        # private look up dict returns an empty list if the block is not found
         self.private_space: int = 0
-        self.private_blocks: list = None
-        self.private_lookup_dict: dict[BlockRequire, list] = None
+        self.private_blocks: list = list()
+        self.private_lookup_dict: dict[BlockSpec, list] = defaultdict(list)

-        self.__allocate_flag = False
-
-    def allocate(self,
-                 public_dtype: torch.dtype = torch.float,
-                 public_block_size: int = 1024,
-                 public_block_number: int = 0,
-                 private_block_list: Iterable[BlockRequire] = ()):
-        assert self.__allocate_flag is False
-        assert public_block_number >= 0
-
-        self.public_free_blocks = list()
-        self.public_used_blocks = set()
-        for _ in range(public_block_number):
-            block = PublicBlock(public_block_size, public_dtype, self.device_type)
-            self.public_free_blocks.append(block)
-
-        if public_block_number <= 0:
-            self.public_space = 0
-        else:
-            self.public_space = self.public_free_blocks[0].memo_occ * public_block_number
-        self.public_block_size = public_block_size
-        self.public_dtype = public_dtype
-
-        self.public_free_cnt = public_block_number
-        self.public_used_cnt = 0
-
-        self.private_space = 0
-        self.private_blocks = list()
-        self.private_lookup_dict = defaultdict(list)
-
-        for require in private_block_list:
-            block = PrivateBlock(require.numel, require.dtype, self.device_type)
-            self.private_space += block.memo_occ
-            self.private_blocks.append(block)
-            self.private_lookup_dict[require].append(block)
-
-        self.__allocate_flag = True
+        # flags for block allcation
+        self.__public_allocated_flag = False
+        self.__private_allocated_flag = False
+
+    def allocate_public_blocks(self, block_num: int, block_spec: BlockSpec = None):
+        """
+        Allocate public tensor blocks for the memory pool. This method will allocate public_block_number blocks with size equal to public_block_size.
+        """
+        assert not self.__public_allocated_flag, 'Public blocks have been allocated to this MemoryPool object, it is not allowed to allocate again.'
+        assert block_num >= 0, f'Expected public_block_number >= 0, but got {block_num}'
+
+        if block_spec is None:
+            block_spec = BlockSpec(numel=1024, dtype=torch.float)
+
+        # allocate public blocks
+        for _ in range(block_num):
+            block = PublicBlock(block_spec.numel, block_spec.dtype, self.device_type)
+            self.public_free_blocks.append(block)
+            self.public_space += block.size_in_bytes
+            self.public_free_count += 1
+
+        # store the block spec info
+        self.public_block_size = block_spec.numel
+        self.public_dtype = block_spec.dtype
+
+    def allocate_private_blocks(self, block_specs: Iterable[BlockSpec]):
+        """
+        Allocate private blocks for the memory pool. This method will allocate private blocks according to the block_specs.
+
+        Args:
+            block_specs (Iterable[BlockSpec]): the block specs of the private blocks to be allocated
+        """
+        assert not self.__private_allocated_flag, 'Private blocks have been allocated to this MemoryPool object, it is not allowed to allocate again.'
+
+        # allocate private blocks
+        for spec in block_specs:
+            block = PrivateBlock(spec.numel, spec.dtype, self.device_type)
+            self.private_space += block.size_in_bytes
+            self.private_blocks.append(block)
+            self.private_lookup_dict[spec].append(block)
+
+        self.__private_allocated_flag = True

     def __repr__(self) -> str:
-        return f'MP(public_space={_format_memory(self.public_space)}, private_space={_format_memory(self.private_space)})'
+        return f'Memory Pool(\n\tpublic_space = {_format_memory(self.public_space)}, \n\tprivate_space={_format_memory(self.private_space)}\n)'

-    def get_private_block(self, numel: int, dtype: torch.dtype):
-        block_list = self.private_lookup_dict.get(BlockRequire(numel=numel, dtype=dtype))
-        return block_list.pop()
+    def get_private_block(self, numel: int, dtype: torch.dtype) -> PrivateBlock:
+        """
+        Get a private block with the given numel and dtype.
+        """
+        block_list = self.private_lookup_dict.get(BlockSpec(numel=numel, dtype=dtype))
+        if len(block_list) == 0:
+            raise ValueError(f'No private block with numel={numel} and dtype={dtype} is found.')
+        else:
+            return block_list.pop()

-    def get_public_block(self):
-        self.public_free_cnt -= 1
-        self.public_used_cnt += 1
+    def pop_public_block(self) -> PublicBlock:
+        """
+        Get a public block from the memory pool.
+        """
+        self.public_free_count -= 1
+        self.public_used_count += 1
         block = self.public_free_blocks.pop()
         self.public_used_blocks.add(block)
         return block

-    def free_public_block(self, block: TensorBlock):
+    def free_public_block(self, block: TensorBlock) -> PublicBlock:
+        """
+        Free a public block to the memory pool.
+
+        Args:
+            block (TensorBlock): the public block to be freed
+        """
         assert isinstance(block, PublicBlock)
-        assert block in self.public_used_blocks
+        assert block in self.public_used_blocks, f'Cound not find the given block in the used public blocks'

-        self.public_free_cnt += 1
-        self.public_used_cnt -= 1
+        # update counter
+        self.public_free_count += 1
+        self.public_used_count -= 1

+        # update free and used blocks
         self.public_used_blocks.remove(block)
         self.public_free_blocks.append(block)
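
For orientation, here is a minimal usage sketch of the refactored MemoryPool API, composed only from calls that appear in this diff and in the updated tests; the sizes and device string are arbitrary:

import torch

from colossalai.elixir.chunk import BlockSpec, MemoryPool

# build a pool and allocate its blocks up front, as the new two-step API requires
mp = MemoryPool(device_type='cuda')
mp.allocate_public_blocks(block_num=4, block_spec=BlockSpec(numel=1024, dtype=torch.float))
mp.allocate_private_blocks([BlockSpec(5, torch.float), BlockSpec(81, torch.float16)])

# public blocks are checked out and returned explicitly;
# public_free_count / public_used_count track the bookkeeping that ChunkGroup now reads
block = mp.pop_public_block()
mp.free_public_block(block)

# private blocks are looked up by their exact (numel, dtype) spec
private = mp.get_private_block(5, torch.float)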

colossalai/elixir/chunk/core/states.py (42)

@@ -2,6 +2,10 @@ from enum import Enum

 class TensorState(Enum):
+    """
+    TensorState represents the state of a tensor in Elixir.
+    There are five states of a tensor: free, compute, hold, hold_after_bwd, ready_for_reduce.
+    """
     FREE = 0
     COMPUTE = 1
     HOLD = 2
@@ -9,17 +13,35 @@ class TensorState(Enum):
     READY_FOR_REDUCE = 4


-# expected: free -> hold -> compute -> hold ->
+# this includes the possible state transition in tensor state:
+# the item in the list is in the format of (old_state, new_state)
+# the complete state transtition is:
+# free -> hold -> compute -> hold ->
 #      -> compute -> hold_after_bwd -> ready_for_reduce
-legal_ts_update_list = [(TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
-                        (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
-                        (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
-                        (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
-                        (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
-                        (TensorState.READY_FOR_REDUCE, TensorState.HOLD)]
+LEGAL_TENSOR_STATE_UPDATE_LIST = [(TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
+                                  (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
+                                  (TensorState.COMPUTE, TensorState.HOLD),
+                                  (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
+                                  (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
+                                  (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
+                                  (TensorState.READY_FOR_REDUCE, TensorState.HOLD)]


-def ts_update_sanity_check(old_state, new_state) -> bool:
-    if (old_state, new_state) not in legal_ts_update_list:
-        raise RuntimeError(f'illegal tensor state updating: {old_state} -> {new_state}')
+def validate_tensor_state_update(old_state: TensorState, new_state: TensorState, raise_exception: bool = False) -> bool:
+    """
+    Validate the tensor state update is legal or not.
+
+    Args:
+        old_state (TensorState): the old state of the tensor
+        new_state (TensorState): the new state of the tensor
+        raise_exception (bool, optional): whether to raise exception when the state update is illegal. Defaults to False.
+
+    Returns:
+        bool: whether the state update is legal or not.
+    """
+    if (old_state, new_state) not in LEGAL_TENSOR_STATE_UPDATE_LIST:
+        if raise_exception:
+            raise RuntimeError(f'Found illegal tensor state updating: {old_state} -> {new_state}')
+        else:
+            return False
     return True
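
A quick illustration of the new helper's behaviour, derived only from the function shown above: it returns a boolean by default and only raises when raise_exception=True, which is how Chunk calls it in this commit.

from colossalai.elixir.chunk.core.states import TensorState, validate_tensor_state_update

# legal transition: returns True
assert validate_tensor_state_update(TensorState.FREE, TensorState.HOLD)

# illegal transition: returns False by default instead of raising
assert not validate_tensor_state_update(TensorState.FREE, TensorState.READY_FOR_REDUCE)

# with raise_exception=True the same illegal transition raises RuntimeError
try:
    validate_tensor_state_update(TensorState.FREE, TensorState.READY_FOR_REDUCE, raise_exception=True)
except RuntimeError as e:
    print(e)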

colossalai/elixir/context/__init__.py (3)

@@ -0,0 +1,3 @@
+from .meta_context import MetaContext
+
+__all__ = ['MetaContext']

colossalai/elixir/ctx/__init__.py → colossalai/elixir/context/meta_context.py (34)

@@ -1,6 +1,6 @@
 import torch

-tensor_creation_methods = dict(tensor=torch.tensor,
+TESNOR_CREATION_METHODS = dict(tensor=torch.tensor,
                                sparse_coo_tensor=torch.sparse_coo_tensor,
                                asarray=torch.asarray,
                                as_tensor=torch.as_tensor,
@@ -29,4 +29,34 @@
                                polar=torch.polar,
                                heaviside=torch.heaviside)

-from .meta_ctx import MetaContext
+
+# TODO: unify this with lazy init context
+class MetaContext(object):
+    """A context manager that wraps all tensor creation methods in torch.
+    By default, all tensors will be created in meta.
+
+    args:
+        device_type: The device type of the tensors to be created.
+    """
+
+    def __init__(self, device_type: str = 'meta') -> None:
+        super().__init__()
+        self.device_type = device_type
+        return None
+
+    def __enter__(self):
+
+        def meta_wrap(func):
+
+            def wrapped_func(*args, **kwargs):
+                kwargs['device'] = self.device_type
+                return func(*args, **kwargs)
+
+            return wrapped_func
+
+        for name, method in TESNOR_CREATION_METHODS.items():
+            setattr(torch, name, meta_wrap(method))
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        for name, method in TESNOR_CREATION_METHODS.items():
+            setattr(torch, name, method)
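
A minimal sketch of how the relocated MetaContext is used, mirroring the updated test_meta_ctx.py below (TEST_MODELS and the 'resnet' builder come from the test utilities, not from this file):

from colossalai.elixir.context import MetaContext
from tests.test_elixir.utils import TEST_MODELS

builder, *_ = TEST_MODELS.get('resnet')

# inside the context, the wrapped torch creation methods are redirected to the
# meta device, so the model is built without allocating real storage
with MetaContext():
    model = builder()

# expected to show the meta device, per the class docstring
print(next(model.parameters()).device)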

colossalai/elixir/ctx/meta_ctx.py (34)

@@ -1,34 +0,0 @@
-import torch
-
-from colossalai.elixir.ctx import tensor_creation_methods
-
-
-class MetaContext(object):
-    """A context manager that wraps all tensor creation methods in torch.
-    By default, all tensors will be created in meta.
-
-    args:
-        device_type: The device type of the tensors to be created.
-    """
-
-    def __init__(self, device_type: str = 'meta') -> None:
-        super().__init__()
-        self.device_type = device_type
-        return None
-
-    def __enter__(self):
-
-        def meta_wrap(func):
-
-            def wrapped_func(*args, **kwargs):
-                kwargs['device'] = self.device_type
-                return func(*args, **kwargs)
-
-            return wrapped_func
-
-        for name, method in tensor_creation_methods.items():
-            setattr(torch, name, meta_wrap(method))
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        for name, method in tensor_creation_methods.items():
-            setattr(torch, name, method)

colossalai/elixir/kernels/__init__.py (4)

@@ -1,4 +1,3 @@
-import torch
 import torch.nn.functional as F

 fused_torch_functions = {F.layer_norm: F.layer_norm}
@@ -12,6 +11,3 @@ def register_fused_layer_norm():
     except:
         print('Cannot import fused layer norm, please install apex from source.')
         pass
-
-
-register_fused_layer_norm()

colossalai/elixir/search/base.py (4)

@@ -5,7 +5,7 @@ from typing import List, Tuple
 import torch
 import torch.nn as nn

-from colossalai.elixir.chunk import BlockRequire, ChunkGroup, MemoryPool
+from colossalai.elixir.chunk import BlockSpec, ChunkGroup, MemoryPool
 from colossalai.elixir.tracer.param_tracer import generate_tf_order
 from colossalai.elixir.tracer.utils import meta_copy
 from colossalai.elixir.utils import print_rank_0
@@ -119,7 +119,7 @@ class SearchBase(ABC):
         for plan in chunk_plans:
             kwargs = plan.kwargs
             if kwargs.get('rcache_fused', False):
-                block_require_list.append(BlockRequire(plan.chunk_size, plan.chunk_dtype))
+                block_require_list.append(BlockSpec(plan.chunk_size, plan.chunk_dtype))

         mp = MemoryPool('cuda')
         mp.allocate(public_dtype=self.unified_dtype,

tests/test_elixir/test_chunk/fetcher_utils.py (12)

@@ -4,7 +4,7 @@ import torch
 import torch.distributed as dist
 import torch.nn as nn

-from colossalai.elixir.chunk import BlockRequire, ChunkFetcher, ChunkGroup, MemoryPool, TensorState
+from colossalai.elixir.chunk import BlockSpec, ChunkFetcher, ChunkGroup, MemoryPool, TensorState
 from colossalai.elixir.chunk.scheduler import FIFOScheduler
 from colossalai.elixir.hook import BufferStore, HookParam
 from colossalai.elixir.tensor import OutplaceTensor
@@ -32,13 +32,15 @@ def grad_handler(grad: torch.Tensor, param: nn.Parameter, fetcher: ChunkFetcher):

 def hook_transform(model: nn.Module, process_group: dist.ProcessGroupGloo):
     pg_size = dist.get_world_size(process_group)

-    private_list = list()
+    mp = MemoryPool('cuda')
+
+    # allocate private blocks
+    private_block_specs = list()
     for param in model.parameters():
         block_size = to_divide(param.numel(), pg_size)
-        private_list.append(BlockRequire(block_size, param.dtype))
-
-    mp = MemoryPool('cuda')
-    mp.allocate(private_block_list=private_list)
+        private_block_specs.append(BlockSpec(block_size, param.dtype))
+    mp.allocate_private_blocks(private_block_specs)

     cg = ChunkGroup(rcache=mp)
     # allocate chunk group
     fused_config = dict(rcache_fused=True)

tests/test_elixir/test_chunk/test_block.py (66)

@@ -1,60 +1,62 @@
 import torch

-from colossalai.elixir.chunk import BlockRequire, MemoryPool, PrivateBlock, PublicBlock
+from colossalai.elixir.chunk import BlockSpec, MemoryPool, PrivateBlock, PublicBlock
 from colossalai.testing import run_on_environment_flag


-@run_on_environment_flag('ELX')
 def test_block():
-    b = PublicBlock(123, torch.float16, 'cuda')
-    payload_b = b.payload
-
-    assert payload_b.numel() == 123
-    assert payload_b.dtype == torch.float16
-    assert payload_b.device.type == 'cuda'
-    assert payload_b.numel() * payload_b.element_size() == b.memo_occ
-
-    c = PrivateBlock(77, torch.float, 'cpu')
-    payload_c = c.payload
-
-    assert payload_c.numel() == 77
-    assert payload_c.dtype == torch.float
-    assert payload_c.device.type == 'cpu'
-    assert payload_c.numel() * payload_c.element_size() == c.memo_occ
+    # test for public block
+    public_block = PublicBlock(123, torch.float16, 'cuda')
+    public_payload = public_block.payload
+
+    assert public_payload.numel() == 123
+    assert public_payload.dtype == torch.float16
+    assert public_payload.device.type == 'cuda'
+    assert public_payload.numel() * public_payload.element_size() == public_block.size_in_bytes
+
+    # test for private block
+    private_block = PrivateBlock(77, torch.float, 'cpu')
+    private_payload = private_block.payload
+
+    assert private_payload.numel() == 77
+    assert private_payload.dtype == torch.float
+    assert private_payload.device.type == 'cpu'
+    assert private_payload.numel() * private_payload.element_size() == private_block.size_in_bytes

     print('test_block: ok')


-@run_on_environment_flag('ELX')
 def test_memory_pool():
     mp = MemoryPool(device_type='cuda')
-    private_list = [BlockRequire(5, torch.float), BlockRequire(81, torch.float16)]
-    mp.allocate(public_block_number=4, private_block_list=private_list)
-
-    block0 = mp.get_public_block()
-
-    assert block0 in mp.public_used_blocks
-    assert mp.public_used_cnt == 1
-    assert mp.public_free_cnt == 3
-
-    block1 = mp.get_public_block()
+
+    # allocate public blocks
+    mp.allocate_public_blocks(block_num=4)
+
+    # allocate private blocks
+    private_block_specs = [BlockSpec(5, torch.float), BlockSpec(81, torch.float16)]
+    mp.allocate_private_blocks(private_block_specs)
+
+    # test for public blocks
+    block0 = mp.pop_public_block()
+    assert block0 in mp.public_used_blocks
+    assert mp.public_used_count == 1
+    assert mp.public_free_count == 3
+
+    block1 = mp.pop_public_block()
     assert block1 in mp.public_used_blocks
-    assert mp.public_used_cnt == 2
-    assert mp.public_free_cnt == 2
+    assert mp.public_used_count == 2
+    assert mp.public_free_count == 2

     mp.free_public_block(block0)
     mp.free_public_block(block1)
     assert block0 in mp.public_free_blocks
     assert block1 in mp.public_free_blocks
-    assert mp.public_used_cnt == 0
-    assert mp.public_free_cnt == 4
+    assert mp.public_used_count == 0
+    assert mp.public_free_count == 4

+    # test for private block
     block0 = mp.get_private_block(5, torch.float)
-
     assert block0.numel == 5
     assert block0.dtype == torch.float

     print('test_memory_pool: ok')

tests/test_elixir/test_chunk/test_chunk.py (36)

@@ -1,13 +1,10 @@
-import os
-from functools import partial
-
 import pytest
 import torch
 import torch.distributed as dist

-from colossalai.elixir.chunk import BlockRequire, Chunk, MemoryPool, TensorState
-from colossalai.elixir.utils import init_distributed
-from colossalai.testing import run_on_environment_flag
+import colossalai
+from colossalai.elixir.chunk import BlockSpec, Chunk, MemoryPool, TensorState
+from colossalai.testing import run_on_environment_flag, spawn


 def exam_chunk_functions(nproc, group):
@@ -21,7 +18,7 @@ def exam_chunk_functions(nproc, group):
     copy_d = d.clone()

     mp = MemoryPool('cuda')
-    mp.allocate(public_block_number=1)
+    mp.allocate_public_blocks(block_num=1)

     chunk = Chunk(mp, 1024, torch.float, group)
     chunk.l2_norm_flag = True
@@ -43,26 +40,31 @@ def exam_chunk_functions(nproc, group):
     chunk.close_chunk()
     assert chunk.is_replica is False

     # check function: get_cpu_copy
     cpu_copys = chunk.get_cpu_copy()
     for t_gpu, t_cpu in zip([copy_a, copy_b, copy_c, copy_d], cpu_copys):
         assert t_cpu.device.type == 'cpu'
         assert torch.equal(t_gpu.cpu(), t_cpu)

     # check function: access_chunk
-    block = mp.get_public_block()
+    block = mp.pop_public_block()
     chunk.access_chunk(block)
     assert chunk.is_replica
     assert chunk.scatter_check
     check_tensors()

     # check function: release_chunk
     chunk.optim_sync_flag = False
     block = chunk.release_chunk()
     assert block in mp.public_used_blocks
     assert chunk.is_replica is False
     assert chunk.optim_sync_flag is True

     # check function: access_chunk after release_chunk
     chunk.access_chunk(block)
     check_tensors()

     # check function: reduce_chunk
     norm = block.payload.float().norm(2)**2
     chunk.reduce_chunk()
@@ -87,9 +89,10 @@ def exam_chunk_states(nproc, group):
     d = torch.randn(4, 32, device='cuda')
     copy_d = d.clone()

-    private = [BlockRequire(1024, torch.float)]
     mp = MemoryPool('cuda')
-    mp.allocate(private_block_list=private)
+
+    private_block_specs = [BlockSpec(1024, torch.float)]
+    mp.allocate_private_blocks(private_block_specs)

     chunk = Chunk(mp, 1024, torch.float, group, rcache_fused=True)
     assert chunk.chunk_size == 1024
@@ -132,23 +135,16 @@ def exam_chunk_states(nproc, group):
     print('chunk states are ok')


-def run_dist(rank, world_size):
-    os.environ['RANK'] = str(rank)
-    os.environ['LOCAL_RANK'] = str(rank)
-    os.environ['WORLD_SIZE'] = str(world_size)
-    os.environ['MASTER_ADDR'] = '127.0.0.1'
-    os.environ['MASTER_PORT'] = str(29512)
-    init_distributed()
+def run_dist(rank, world_size, port):
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
     exam_chunk_functions(nproc=world_size, group=dist.GroupMember.WORLD)
     exam_chunk_states(nproc=world_size, group=dist.GroupMember.WORLD)


 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2, 4])
-@run_on_environment_flag('ELX')
 def test_chunk_functions(world_size):
-    run_func = partial(run_dist, world_size=world_size)
-    torch.multiprocessing.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, nprocs=world_size)


 if __name__ == '__main__':
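
The test files in this commit all migrate from hand-rolled os.environ setup plus torch.multiprocessing.spawn to colossalai.launch and colossalai.testing.spawn. A minimal sketch of that pattern as it appears in the diffs; the test name and the assertion in the body are illustrative placeholders, not part of this commit:

import pytest
import torch.distributed as dist

import colossalai
from colossalai.testing import spawn


def run_dist(rank, world_size, port):
    # spawn passes rank/world_size/port; colossalai.launch initializes the process group
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    # ... exercise the chunk APIs against dist.GroupMember.WORLD here ...
    assert dist.get_world_size() == world_size


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2])
def test_something(world_size):
    spawn(run_dist, nprocs=world_size)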

tests/test_elixir/test_chunk/test_fetcher.py (25)

@@ -1,6 +1,4 @@
 import copy
-import os
-from functools import partial

 import pytest
 import torch
@@ -8,9 +6,10 @@ import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.testing import assert_close

+import colossalai
 from colossalai.elixir.chunk import ChunkGroup
-from colossalai.elixir.utils import init_distributed, seed_all
-from colossalai.testing import run_on_environment_flag
+from colossalai.elixir.utils import seed_all
+from colossalai.testing import run_on_environment_flag, spawn
 from tests.test_elixir.test_chunk.fetcher_utils import hook_transform
 from tests.test_elixir.utils import TEST_MODELS, to_cuda
@@ -25,7 +24,7 @@ def check_gradient(ddp_model, my_model, cg: ChunkGroup):
             assert_close(p0.grad.data, p1.data)


-def exam_chunk_fetcher(nproc, group):
+def exam_chunk_fetcher(group):
     model_fn, data_fn = TEST_MODELS.get('resnet')
     torch_model = model_fn().cuda()
     test_model = copy.deepcopy(torch_model)
@@ -49,23 +48,17 @@ def exam_chunk_fetcher(nproc, group):
     print('private chunk fetcher is ok')


-def run_dist(rank, world_size):
-    os.environ['RANK'] = str(rank)
-    os.environ['LOCAL_RANK'] = str(rank)
-    os.environ['WORLD_SIZE'] = str(world_size)
-    os.environ['MASTER_ADDR'] = '127.0.0.1'
-    os.environ['MASTER_PORT'] = str(29512)
-    init_distributed()
-    exam_chunk_fetcher(nproc=world_size, group=dist.GroupMember.WORLD)
+def run_dist(rank, world_size, port):
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+    exam_chunk_fetcher(group=dist.GroupMember.WORLD)


 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2, 4])
-@run_on_environment_flag('ELX')
 def test_chunk_fetcher(world_size):
-    run_func = partial(run_dist, world_size=world_size)
-    torch.multiprocessing.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, nprocs=world_size)


 if __name__ == '__main__':
     test_chunk_fetcher(world_size=2)

tests/test_elixir/test_chunk/test_group.py (30)

@@ -1,16 +1,13 @@
-import os
-from functools import partial
-
 import pytest
 import torch
 import torch.distributed as dist

-from colossalai.elixir.chunk import BlockRequire, ChunkGroup, MemoryPool, TensorState
-from colossalai.elixir.utils import init_distributed
-from colossalai.testing import run_on_environment_flag
+import colossalai
+from colossalai.elixir.chunk import BlockSpec, ChunkGroup, MemoryPool, TensorState
+from colossalai.testing import run_on_environment_flag, spawn


-def exam_chunk_group_functions(nproc, group):
+def exam_chunk_group_functions(group):
     a = torch.randn(3, 64, device='cuda')
     copy_a = a.clone()
     b = torch.randn(2, 32, device='cuda')
@@ -23,7 +20,9 @@ def exam_chunk_group_functions(nproc, group):
     copy_e = e.clone()

     mp = MemoryPool('cuda')
-    mp.allocate(public_block_size=256, public_block_number=2, private_block_list=[BlockRequire(68, torch.float)])
+    mp.allocate_public_blocks(block_num=2, block_spec=BlockSpec(numel=256, dtype=torch.float))
+    mp.allocate_private_blocks([BlockSpec(68, torch.float)])
+
     cg = ChunkGroup(rcache=mp)
     c0 = cg.allocate_chunk([a, b], 256, torch.float, group)
     c1 = cg.allocate_chunk([c], 256, torch.float, group)
@@ -76,22 +75,15 @@ def exam_chunk_group_functions(nproc, group):
     print('chunk group functions are ok')


-def run_dist(rank, world_size):
-    os.environ['RANK'] = str(rank)
-    os.environ['LOCAL_RANK'] = str(rank)
-    os.environ['WORLD_SIZE'] = str(world_size)
-    os.environ['MASTER_ADDR'] = '127.0.0.1'
-    os.environ['MASTER_PORT'] = str(29512)
-    init_distributed()
-    exam_chunk_group_functions(nproc=world_size, group=dist.GroupMember.WORLD)
+def run_dist(rank, world_size, port):
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+    exam_chunk_group_functions(group=dist.GroupMember.WORLD)


 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2, 4])
-@run_on_environment_flag('ELX')
 def test_chunk_group(world_size):
-    run_func = partial(run_dist, world_size=world_size)
-    torch.multiprocessing.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, nprocs=world_size)


 if __name__ == '__main__':

tests/test_elixir/test_chunk/test_scheduler.py (31)

@@ -1,19 +1,16 @@
-import os
-from functools import partial
-
 import pytest
 import torch
 import torch.distributed as dist

+import colossalai
 from colossalai.elixir.chunk import Chunk, MemoryPool
 from colossalai.elixir.chunk.scheduler import FIFOScheduler, PrefetchScheduler
-from colossalai.elixir.utils import init_distributed
-from colossalai.testing import run_on_environment_flag
+from colossalai.testing import spawn


-def exam_fifo(nproc, group):
+def exam_fifo(group):
     mp = MemoryPool('cuda')
-    mp.allocate(public_block_number=1)
+    mp.allocate_public_blocks(block_num=1)
     c0 = Chunk(mp, 1024, torch.float, group)
     c1 = Chunk(mp, 1024, torch.float, group)
     c2 = Chunk(mp, 1024, torch.float, group)
@@ -40,9 +37,8 @@ def exam_fifo(nproc, group):
     assert sdl.top() == c0


-def exam_prefetch(nproc, group):
+def exam_prefetch(group):
     mp = MemoryPool('cuda')
-    mp.allocate()
     c0 = Chunk(mp, 1024, torch.float, group)
     c1 = Chunk(mp, 1024, torch.float, group)
     c2 = Chunk(mp, 1024, torch.float, group)
@@ -108,22 +104,15 @@ def exam_prefetch(nproc, group):
     sdl.clear()


-def run_dist(rank, world_size):
-    os.environ['RANK'] = str(rank)
-    os.environ['LOCAL_RANK'] = str(rank)
-    os.environ['WORLD_SIZE'] = str(world_size)
-    os.environ['MASTER_ADDR'] = '127.0.0.1'
-    os.environ['MASTER_PORT'] = str(29512)
-    init_distributed()
-    exam_fifo(nproc=world_size, group=dist.GroupMember.WORLD)
-    exam_prefetch(nproc=world_size, group=dist.GroupMember.WORLD)
+def run_dist(rank, world_size, port):
+    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
+    exam_fifo(group=dist.GroupMember.WORLD)
+    exam_prefetch(group=dist.GroupMember.WORLD)


 @pytest.mark.dist
-@run_on_environment_flag('ELX')
 def test_chunk_scheduler(world_size=1):
-    run_func = partial(run_dist, world_size=world_size)
-    torch.multiprocessing.spawn(run_func, nprocs=world_size)
+    spawn(run_dist, nprocs=world_size)


 if __name__ == '__main__':

tests/test_elixir/test_ctx/test_meta_ctx.py (4)

@@ -1,9 +1,7 @@
-from colossalai.elixir.ctx import MetaContext
-from colossalai.testing import run_on_environment_flag
+from colossalai.elixir.context import MetaContext
 from tests.test_elixir.utils import TEST_MODELS


-@run_on_environment_flag('ELX')
 def test_meta_context():
     builder, *_ = TEST_MODELS.get('resnet')
     with MetaContext():
