
[elixir] refactored the chunk module (#3956)

Branch: feature/elixir
Frank Lee, 1 year ago, committed by GitHub
Parent commit: c173a69b3e
1. colossalai/elixir/chunk/__init__.py (7 changed lines)
2. colossalai/elixir/chunk/core/__init__.py (6 changed lines)
3. colossalai/elixir/chunk/core/chunk.py (10 changed lines)
4. colossalai/elixir/chunk/core/group.py (4 changed lines)
5. colossalai/elixir/chunk/core/memory_pool.py (200 changed lines)
6. colossalai/elixir/chunk/core/states.py (42 changed lines)
7. colossalai/elixir/context/__init__.py (3 changed lines)
8. colossalai/elixir/context/meta_context.py (34 changed lines)
9. colossalai/elixir/ctx/meta_ctx.py (34 changed lines)
10. colossalai/elixir/kernels/__init__.py (4 changed lines)
11. colossalai/elixir/search/base.py (4 changed lines)
12. tests/test_elixir/test_chunk/fetcher_utils.py (12 changed lines)
13. tests/test_elixir/test_chunk/test_block.py (66 changed lines)
14. tests/test_elixir/test_chunk/test_chunk.py (36 changed lines)
15. tests/test_elixir/test_chunk/test_fetcher.py (25 changed lines)
16. tests/test_elixir/test_chunk/test_group.py (30 changed lines)
17. tests/test_elixir/test_chunk/test_scheduler.py (31 changed lines)
18. tests/test_elixir/test_ctx/test_meta_ctx.py (4 changed lines)

colossalai/elixir/chunk/__init__.py (7 changed lines)

@ -1,2 +1,7 @@
from .core import BlockRequire, Chunk, ChunkGroup, MemoryPool, PrivateBlock, PublicBlock, TensorBlock, TensorState
from .core import BlockSpec, Chunk, ChunkGroup, MemoryPool, PrivateBlock, PublicBlock, TensorBlock, TensorState
from .fetcher import ChunkFetcher
__all__ = [
'BlockSpec', 'Chunk', 'ChunkGroup', 'MemoryPool', 'PrivateBlock', 'PublicBlock', 'TensorBlock', 'TensorState',
'ChunkFetcher'
]

colossalai/elixir/chunk/core/__init__.py (6 changed lines)

@ -1,4 +1,8 @@
from .chunk import Chunk
from .group import ChunkGroup
from .memory_pool import BlockRequire, MemoryPool, PrivateBlock, PublicBlock, TensorBlock
from .memory_pool import BlockSpec, MemoryPool, PrivateBlock, PublicBlock, TensorBlock
from .states import TensorState
__all__ = [
'Chunk', 'ChunkGroup', 'BlockSpec', 'MemoryPool', 'PrivateBlock', 'PublicBlock', 'TensorBlock', 'TensorState'
]

colossalai/elixir/chunk/core/chunk.py (10 changed lines)

@ -8,8 +8,8 @@ from torch.distributed import ProcessGroup
from colossalai.elixir.cuda import gpu_device
from colossalai.elixir.tensor import FakeTensor
from .memory_pool import MemoryPool, PrivateBlock, PublicBlock, TensorBlock
from .states import TensorState, ts_update_sanity_check
from .memory_pool import MemoryPool, TensorBlock
from .states import TensorState, validate_tensor_state_update
class ChunkFullError(Exception):
@ -383,7 +383,11 @@ class Chunk:
prev_state = self.tensors_info[tensor].state
if prev_state == tensor_state:
return
if ts_update_sanity_check(prev_state, tensor_state):
# validate whether the update is legal
# if illegal, raise an exception
is_update_valid = validate_tensor_state_update(prev_state, tensor_state, raise_exception=True)
if is_update_valid:
self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state)
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:

colossalai/elixir/chunk/core/group.py (4 changed lines)

@ -129,7 +129,7 @@ class ChunkGroup(object):
"""Check whether the rcache has enough blocks to store the gathered chunk."""
if chunk.rcache_fused:
return True
return self.rcache.public_free_cnt > 0
return self.rcache.public_free_count > 0
def access_chunk(self, chunk: Chunk) -> bool:
"""Access a chunk into rCache."""
@ -141,7 +141,7 @@ class ChunkGroup(object):
if chunk.rcache_fused:
block = None
else:
block = self.rcache.get_public_block()
block = self.rcache.pop_public_block()
chunk.access_chunk(block)
self.__add_to_accset(chunk)
return True

colossalai/elixir/chunk/core/memory_pool.py (200 changed lines)

@ -1,35 +1,52 @@
from abc import ABC
from collections import defaultdict
from enum import Enum
from typing import Iterable, NamedTuple
import torch
from torch.autograd.profiler_util import _format_memory
class BlockRequire(NamedTuple):
class BlockSpec(NamedTuple):
"""
BlockSpec is the specification of a block. It contains the number of elements and the data type of the block.
Args:
numel (int): the number of elements in the block
dtype (torch.dtype): the data type of the block
"""
numel: int
dtype: torch.dtype
class BlockType(Enum):
"""
BlockType is the type of a block. There are two types of blocks: public and private.
"""
PUBLIC = 0
PRIVATE = 1
class TensorBlock(ABC):
"""TensorBlock is the memory unit of memory pool.
It is a continuous memory block used to store tensors.
"""
TensorBlock is the memory unit of memory pool. It is a contiguous memory block used to store tensors.
Each chunk needs a corresponding TensorBlock to store its data during training.
args:
numel: the number of elements in the block
dtype: the data type of the block
device_type: the device type of the block
size (int): the number of elements in the block
dtype (torch.dtype): the data type of the block
device_type (str): the device type of the block
"""
total_count: int = 0
def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
def __init__(self, size: int, dtype: torch.dtype, device_type: str, block_type: BlockType) -> None:
self.block_id = TensorBlock.total_count
TensorBlock.total_count += 1
self.device_type = device_type
self.payload: torch.Tensor = torch.empty((numel,), dtype=dtype, device=device_type)
self.memo_occ: int = self.payload.numel() * self.payload.element_size()
self.payload: torch.Tensor = torch.empty((size,), dtype=dtype, device=device_type)
self.size_in_bytes: int = self.payload.numel() * self.payload.element_size()
self.block_type = block_type
@property
def numel(self):
@ -50,122 +67,145 @@ class TensorBlock(ABC):
return self.block_id == other.block_id
def __repr__(self) -> str:
return f'(id={self.block_id}, numel={self.numel}, device={self.device_type}, dtype={self.dtype}, memo={self.memo_occ})'
return f'{self.block_type}(\n\tID = {self.block_id}, \n\tsize = {self.numel}, \n\tdevice = {self.device_type}, \n\tdtype = {self.dtype}, \n\tsize in bytes={self.size_in_bytes}\n)'
class PublicBlock(TensorBlock):
"""Public blocks have the same length.
Chunks of the same length can share the same public block.
"""
Public blocks have the same length. Chunks of the same length can share the same public block.
"""
def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
super().__init__(numel, dtype, device_type)
self.block_type = 'public'
def __repr__(self) -> str:
return f'PublicBlock{super().__repr__()}'
super().__init__(numel, dtype, device_type, BlockType.PUBLIC)
class PrivateBlock(TensorBlock):
"""Private blocks may have different lengths.
Each private chunk should use its own private block.
"""
Private blocks may have different lengths. Each private chunk should use its own private block.
"""
def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
super().__init__(numel, dtype, device_type)
self.block_type = 'private'
def __repr__(self) -> str:
return f'PrivateBlock{super().__repr__()}'
super().__init__(numel, dtype, device_type, BlockType.PRIVATE)
class MemoryPool(object):
"""A memory pool consists of public blocks and private blocks.
"""
A memory pool consists of public blocks and private blocks.
rCache uses the memory pool to manage memory blocks.
Users should allocate memory blocks before using it.
args:
device_type: the device type of the memory pool
device_type (str): the device type of the memory pool
"""
def __init__(self, device_type: str) -> None:
assert device_type in [
'cuda', 'cpu'
], f'Expected device type to be cuda or cpu, but got an invalid device type: {device_type}'
self.device_type: str = device_type
# public space
# public space = number of public block x the public block size in bytes
# all public blocks have the same block size
self.public_space: int = 0
self.public_block_size: int = 0
self.public_dtype: torch.dtype = None
self.public_free_blocks: list = None
self.public_used_blocks: set = None
self.public_free_cnt: int = 0
self.public_used_cnt: int = 0
# create block holder and counter
self.public_free_blocks: list = list()
self.public_used_blocks: set = set()
self.public_free_count: int = 0
self.public_used_count: int = 0
# private space
# private look up dict returns an empty list if the block is not found
self.private_space: int = 0
self.private_blocks: list = None
self.private_lookup_dict: dict[BlockRequire, list] = None
self.__allocate_flag = False
def allocate(self,
public_dtype: torch.dtype = torch.float,
public_block_size: int = 1024,
public_block_number: int = 0,
private_block_list: Iterable[BlockRequire] = ()):
assert self.__allocate_flag is False
assert public_block_number >= 0
self.public_free_blocks = list()
self.public_used_blocks = set()
for _ in range(public_block_number):
block = PublicBlock(public_block_size, public_dtype, self.device_type)
self.private_blocks: list = list()
self.private_lookup_dict: dict[BlockSpec, list] = defaultdict(list)
# flags for block allocation
self.__public_allocated_flag = False
self.__private_allocated_flag = False
def allocate_public_blocks(self, block_num: int, block_spec: BlockSpec = None):
"""
Allocate public tensor blocks for the memory pool. This method allocates block_num public blocks, each sized and typed according to block_spec.
"""
assert not self.__public_allocated_flag, 'Public blocks have been allocated to this MemoryPool object, it is not allowed to allocate again.'
assert block_num >= 0, f'Expected public_block_number >= 0, but got {block_num}'
if block_spec is None:
block_spec = BlockSpec(numel=1024, dtype=torch.float)
# allocate public blocks
for _ in range(block_num):
block = PublicBlock(block_spec.numel, block_spec.dtype, self.device_type)
self.public_free_blocks.append(block)
if public_block_number <= 0:
self.public_space = 0
else:
self.public_space = self.public_free_blocks[0].memo_occ * public_block_number
self.public_block_size = public_block_size
self.public_dtype = public_dtype
self.public_free_cnt = public_block_number
self.public_used_cnt = 0
self.private_space = 0
self.private_blocks = list()
self.private_lookup_dict = defaultdict(list)
for require in private_block_list:
block = PrivateBlock(require.numel, require.dtype, self.device_type)
self.private_space += block.memo_occ
self.public_space += block.size_in_bytes
self.public_free_count += 1
# store the block spec info
self.public_block_size = block_spec.numel
self.public_dtype = block_spec.dtype
def allocate_private_blocks(self, block_specs: Iterable[BlockSpec]):
"""
Allocate private blocks for the memory pool. This method will allocate private blocks according to the block_specs.
Args:
block_specs (Iterable[BlockSpec]): the block specs of the private blocks to be allocated
"""
# allocate private blocks
assert not self.__private_allocated_flag, 'Private blocks have been allocated to this MemoryPool object, it is not allowed to allocate again.'
for spec in block_specs:
block = PrivateBlock(spec.numel, spec.dtype, self.device_type)
self.private_space += block.size_in_bytes
self.private_blocks.append(block)
self.private_lookup_dict[require].append(block)
self.private_lookup_dict[spec].append(block)
self.__allocate_flag = True
self.__private_allocated_flag = True
def __repr__(self) -> str:
return f'MP(public_space={_format_memory(self.public_space)}, private_space={_format_memory(self.private_space)})'
return f'Memory Pool(\n\tpublic_space = {_format_memory(self.public_space)}, \n\tprivate_space={_format_memory(self.private_space)}\n)'
def get_private_block(self, numel: int, dtype: torch.dtype) -> PrivateBlock:
"""
Get a private block with the given numel and dtype.
"""
block_list = self.private_lookup_dict.get(BlockSpec(numel=numel, dtype=dtype))
def get_private_block(self, numel: int, dtype: torch.dtype):
block_list = self.private_lookup_dict.get(BlockRequire(numel=numel, dtype=dtype))
return block_list.pop()
if len(block_list) == 0:
raise ValueError(f'No private block with numel={numel} and dtype={dtype} is found.')
else:
return block_list.pop()
def get_public_block(self):
self.public_free_cnt -= 1
self.public_used_cnt += 1
def pop_public_block(self) -> PublicBlock:
"""
Get a public block from the memory pool.
"""
self.public_free_count -= 1
self.public_used_count += 1
block = self.public_free_blocks.pop()
self.public_used_blocks.add(block)
return block
def free_public_block(self, block: TensorBlock):
def free_public_block(self, block: TensorBlock) -> PublicBlock:
"""
Free a public block to the memory pool.
Args:
block (TensorBlock): the public block to be freed
"""
assert isinstance(block, PublicBlock)
assert block in self.public_used_blocks
assert block in self.public_used_blocks, f'Could not find the given block in the used public blocks'
self.public_free_cnt += 1
self.public_used_cnt -= 1
# update counter
self.public_free_count += 1
self.public_used_count -= 1
# update free and used blocks
self.public_used_blocks.remove(block)
self.public_free_blocks.append(block)
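
For orientation, below is a minimal usage sketch of the refactored MemoryPool interface as it appears in this diff: the old single allocate() call is split into allocate_public_blocks() and allocate_private_blocks(). The block sizes, dtypes, and the 'cpu' device are illustrative only.

```python
import torch

from colossalai.elixir.chunk import BlockSpec, MemoryPool

# build a pool (use 'cuda' if a GPU is available) and allocate four public
# blocks that all share the same 1024-element fp32 spec
pool = MemoryPool(device_type='cpu')
pool.allocate_public_blocks(block_num=4, block_spec=BlockSpec(numel=1024, dtype=torch.float))

# private blocks are allocated one per spec and later looked up by (numel, dtype)
pool.allocate_private_blocks([BlockSpec(numel=512, dtype=torch.float),
                              BlockSpec(numel=2048, dtype=torch.float16)])

# public blocks are popped from the free list and must be returned explicitly
block = pool.pop_public_block()
assert pool.public_used_count == 1 and pool.public_free_count == 3
pool.free_public_block(block)

# a private block is fetched by its exact numel and dtype
private = pool.get_private_block(512, torch.float)
assert private.numel == 512 and private.dtype == torch.float
```

The counters public_free_count and public_used_count replace the old *_cnt fields, which is why ChunkGroup in group.py above now reads public_free_count and calls pop_public_block().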

colossalai/elixir/chunk/core/states.py (42 changed lines)

@ -2,6 +2,10 @@ from enum import Enum
class TensorState(Enum):
"""
TensorState represents the state of a tensor in Elixir.
There are five states of a tensor: free, compute, hold, hold_after_bwd, ready_for_reduce.
"""
FREE = 0
COMPUTE = 1
HOLD = 2
@ -9,17 +13,35 @@ class TensorState(Enum):
READY_FOR_REDUCE = 4
# expected: free -> hold -> compute -> hold ->
# this lists the possible state transitions of a tensor state:
# each item is in the format (old_state, new_state)
# the complete state transition is:
# free -> hold -> compute -> hold ->
# -> compute -> hold_after_bwd -> ready_for_reduce
legal_ts_update_list = [(TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
(TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
(TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
(TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
(TensorState.READY_FOR_REDUCE, TensorState.HOLD)]
LEGAL_TENSOR_STATE_UPDATE_LIST = [(TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
(TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
(TensorState.COMPUTE, TensorState.HOLD),
(TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
(TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE),
(TensorState.READY_FOR_REDUCE, TensorState.HOLD)]
def ts_update_sanity_check(old_state, new_state) -> bool:
if (old_state, new_state) not in legal_ts_update_list:
raise RuntimeError(f'illegal tensor state updating: {old_state} -> {new_state}')
def validate_tensor_state_update(old_state: TensorState, new_state: TensorState, raise_exception: bool = False) -> bool:
"""
Validate the tensor state update is legal or not.
Args:
old_state (TensorState): the old state of the tensor
new_state (TensorState): the new state of the tensor
raise_exception (bool, optional): whether to raise exception when the state update is illegal. Defaults to False.
Returns:
bool: whether the state update is legal or not.
"""
if (old_state, new_state) not in LEGAL_TENSOR_STATE_UPDATE_LIST:
if raise_exception:
raise RuntimeError(f'Found illegal tensor state updating: {old_state} -> {new_state}')
else:
return False
return True
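
A short sketch of how the renamed helper behaves, based on the transition table above; the helper is imported directly from the module, and the surrounding script is illustrative only.

```python
from colossalai.elixir.chunk.core.states import TensorState, validate_tensor_state_update

# a transition listed in LEGAL_TENSOR_STATE_UPDATE_LIST returns True
assert validate_tensor_state_update(TensorState.FREE, TensorState.HOLD)

# an unlisted transition returns False by default ...
assert not validate_tensor_state_update(TensorState.FREE, TensorState.READY_FOR_REDUCE)

# ... or raises when raise_exception=True, which is how the call site in chunk.py above uses it
try:
    validate_tensor_state_update(TensorState.HOLD, TensorState.READY_FOR_REDUCE, raise_exception=True)
except RuntimeError as err:
    print(err)  # Found illegal tensor state updating: ...
```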

colossalai/elixir/context/__init__.py (3 changed lines)

@ -0,0 +1,3 @@
from .meta_context import MetaContext
__all__ = ['MetaContext']

colossalai/elixir/ctx/__init__.py → colossalai/elixir/context/meta_context.py (34 changed lines)

@ -1,6 +1,6 @@
import torch
tensor_creation_methods = dict(tensor=torch.tensor,
TESNOR_CREATION_METHODS = dict(tensor=torch.tensor,
sparse_coo_tensor=torch.sparse_coo_tensor,
asarray=torch.asarray,
as_tensor=torch.as_tensor,
@ -29,4 +29,34 @@ tensor_creation_methods = dict(tensor=torch.tensor,
polar=torch.polar,
heaviside=torch.heaviside)
from .meta_ctx import MetaContext
# TODO: unify this with lazy init context
class MetaContext(object):
"""A context manager that wraps all tensor creation methods in torch.
By default, all tensors will be created in meta.
args:
device_type: The device type of the tensors to be created.
"""
def __init__(self, device_type: str = 'meta') -> None:
super().__init__()
self.device_type = device_type
return None
def __enter__(self):
def meta_wrap(func):
def wrapped_func(*args, **kwargs):
kwargs['device'] = self.device_type
return func(*args, **kwargs)
return wrapped_func
for name, method in TESNOR_CREATION_METHODS.items():
setattr(torch, name, meta_wrap(method))
def __exit__(self, exc_type, exc_val, exc_tb):
for name, method in TESNOR_CREATION_METHODS.items():
setattr(torch, name, method)
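
To illustrate the relocated context manager, here is a minimal sketch of MetaContext in use. torch.tensor is one of the wrapped factory functions shown in the table above; the example assumes a PyTorch version that accepts device='meta' in torch.tensor, as recent releases do.

```python
import torch

from colossalai.elixir.context import MetaContext

# inside the context, the wrapped factory functions force device='meta',
# so no real storage is allocated for tensors created here
with MetaContext():
    t = torch.tensor([1.0, 2.0, 3.0])
    print(t.device)  # expected: meta

# on exit the original factory functions are restored
t = torch.tensor([1.0, 2.0, 3.0])
print(t.device)  # expected: cpu (the default device)
```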

colossalai/elixir/ctx/meta_ctx.py (34 changed lines)

@ -1,34 +0,0 @@
import torch
from colossalai.elixir.ctx import tensor_creation_methods
class MetaContext(object):
"""A context manager that wraps all tensor creation methods in torch.
By default, all tensors will be created in meta.
args:
device_type: The device type of the tensors to be created.
"""
def __init__(self, device_type: str = 'meta') -> None:
super().__init__()
self.device_type = device_type
return None
def __enter__(self):
def meta_wrap(func):
def wrapped_func(*args, **kwargs):
kwargs['device'] = self.device_type
return func(*args, **kwargs)
return wrapped_func
for name, method in tensor_creation_methods.items():
setattr(torch, name, meta_wrap(method))
def __exit__(self, exc_type, exc_val, exc_tb):
for name, method in tensor_creation_methods.items():
setattr(torch, name, method)

colossalai/elixir/kernels/__init__.py (4 changed lines)

@ -1,4 +1,3 @@
import torch
import torch.nn.functional as F
fused_torch_functions = {F.layer_norm: F.layer_norm}
@ -12,6 +11,3 @@ def register_fused_layer_norm():
except:
print('Cannot import fused layer norm, please install apex from source.')
pass
register_fused_layer_norm()

colossalai/elixir/search/base.py (4 changed lines)

@ -5,7 +5,7 @@ from typing import List, Tuple
import torch
import torch.nn as nn
from colossalai.elixir.chunk import BlockRequire, ChunkGroup, MemoryPool
from colossalai.elixir.chunk import BlockSpec, ChunkGroup, MemoryPool
from colossalai.elixir.tracer.param_tracer import generate_tf_order
from colossalai.elixir.tracer.utils import meta_copy
from colossalai.elixir.utils import print_rank_0
@ -119,7 +119,7 @@ class SearchBase(ABC):
for plan in chunk_plans:
kwargs = plan.kwargs
if kwargs.get('rcache_fused', False):
block_require_list.append(BlockRequire(plan.chunk_size, plan.chunk_dtype))
block_require_list.append(BlockSpec(plan.chunk_size, plan.chunk_dtype))
mp = MemoryPool('cuda')
mp.allocate(public_dtype=self.unified_dtype,

tests/test_elixir/test_chunk/fetcher_utils.py (12 changed lines)

@ -4,7 +4,7 @@ import torch
import torch.distributed as dist
import torch.nn as nn
from colossalai.elixir.chunk import BlockRequire, ChunkFetcher, ChunkGroup, MemoryPool, TensorState
from colossalai.elixir.chunk import BlockSpec, ChunkFetcher, ChunkGroup, MemoryPool, TensorState
from colossalai.elixir.chunk.scheduler import FIFOScheduler
from colossalai.elixir.hook import BufferStore, HookParam
from colossalai.elixir.tensor import OutplaceTensor
@ -32,13 +32,15 @@ def grad_handler(grad: torch.Tensor, param: nn.Parameter, fetcher: ChunkFetcher)
def hook_transform(model: nn.Module, process_group: dist.ProcessGroupGloo):
pg_size = dist.get_world_size(process_group)
private_list = list()
mp = MemoryPool('cuda')
# allocate private blocks
private_block_specs = list()
for param in model.parameters():
block_size = to_divide(param.numel(), pg_size)
private_list.append(BlockRequire(block_size, param.dtype))
private_block_specs.append(BlockSpec(block_size, param.dtype))
mp.allocate_private_blocks(private_block_specs)
mp = MemoryPool('cuda')
mp.allocate(private_block_list=private_list)
cg = ChunkGroup(rcache=mp)
# allocate chunk group
fused_config = dict(rcache_fused=True)

tests/test_elixir/test_chunk/test_block.py (66 changed lines)

@ -1,60 +1,62 @@
import torch
from colossalai.elixir.chunk import BlockRequire, MemoryPool, PrivateBlock, PublicBlock
from colossalai.elixir.chunk import BlockSpec, MemoryPool, PrivateBlock, PublicBlock
from colossalai.testing import run_on_environment_flag
@run_on_environment_flag('ELX')
def test_block():
b = PublicBlock(123, torch.float16, 'cuda')
payload_b = b.payload
assert payload_b.numel() == 123
assert payload_b.dtype == torch.float16
assert payload_b.device.type == 'cuda'
assert payload_b.numel() * payload_b.element_size() == b.memo_occ
c = PrivateBlock(77, torch.float, 'cpu')
payload_c = c.payload
assert payload_c.numel() == 77
assert payload_c.dtype == torch.float
assert payload_c.device.type == 'cpu'
assert payload_c.numel() * payload_c.element_size() == c.memo_occ
# test for public block
public_block = PublicBlock(123, torch.float16, 'cuda')
public_payload = public_block.payload
assert public_payload.numel() == 123
assert public_payload.dtype == torch.float16
assert public_payload.device.type == 'cuda'
assert public_payload.numel() * public_payload.element_size() == public_block.size_in_bytes
# test for private block
private_block = PrivateBlock(77, torch.float, 'cpu')
private_payload = private_block.payload
assert private_payload.numel() == 77
assert private_payload.dtype == torch.float
assert private_payload.device.type == 'cpu'
assert private_payload.numel() * private_payload.element_size() == private_block.size_in_bytes
print('test_block: ok')
@run_on_environment_flag('ELX')
def test_memory_pool():
mp = MemoryPool(device_type='cuda')
private_list = [BlockRequire(5, torch.float), BlockRequire(81, torch.float16)]
mp.allocate(public_block_number=4, private_block_list=private_list)
block0 = mp.get_public_block()
# allocate public blocks
mp.allocate_public_blocks(block_num=4)
assert block0 in mp.public_used_blocks
assert mp.public_used_cnt == 1
assert mp.public_free_cnt == 3
# allocate private blocks
private_block_specs = [BlockSpec(5, torch.float), BlockSpec(81, torch.float16)]
mp.allocate_private_blocks(private_block_specs)
block1 = mp.get_public_block()
# test for public blocks
block0 = mp.pop_public_block()
assert block0 in mp.public_used_blocks
assert mp.public_used_count == 1
assert mp.public_free_count == 3
block1 = mp.pop_public_block()
assert block1 in mp.public_used_blocks
assert mp.public_used_cnt == 2
assert mp.public_free_cnt == 2
assert mp.public_used_count == 2
assert mp.public_free_count == 2
mp.free_public_block(block0)
mp.free_public_block(block1)
assert block0 in mp.public_free_blocks
assert block1 in mp.public_free_blocks
assert mp.public_used_cnt == 0
assert mp.public_free_cnt == 4
assert mp.public_used_count == 0
assert mp.public_free_count == 4
# test for private block
block0 = mp.get_private_block(5, torch.float)
assert block0.numel == 5
assert block0.dtype == torch.float
print('test_memory_pool: ok')

tests/test_elixir/test_chunk/test_chunk.py (36 changed lines)

@ -1,13 +1,10 @@
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
from colossalai.elixir.chunk import BlockRequire, Chunk, MemoryPool, TensorState
from colossalai.elixir.utils import init_distributed
from colossalai.testing import run_on_environment_flag
import colossalai
from colossalai.elixir.chunk import BlockSpec, Chunk, MemoryPool, TensorState
from colossalai.testing import run_on_environment_flag, spawn
def exam_chunk_functions(nproc, group):
@ -21,7 +18,7 @@ def exam_chunk_functions(nproc, group):
copy_d = d.clone()
mp = MemoryPool('cuda')
mp.allocate(public_block_number=1)
mp.allocate_public_blocks(block_num=1)
chunk = Chunk(mp, 1024, torch.float, group)
chunk.l2_norm_flag = True
@ -43,26 +40,31 @@ def exam_chunk_functions(nproc, group):
chunk.close_chunk()
assert chunk.is_replica is False
# check function: get_cpu_copy
cpu_copys = chunk.get_cpu_copy()
for t_gpu, t_cpu in zip([copy_a, copy_b, copy_c, copy_d], cpu_copys):
assert t_cpu.device.type == 'cpu'
assert torch.equal(t_gpu.cpu(), t_cpu)
# check function: access_chunk
block = mp.get_public_block()
block = mp.pop_public_block()
chunk.access_chunk(block)
assert chunk.is_replica
assert chunk.scatter_check
check_tensors()
# check function: release_chunk
chunk.optim_sync_flag = False
block = chunk.release_chunk()
assert block in mp.public_used_blocks
assert chunk.is_replica is False
assert chunk.optim_sync_flag is True
# check function: access_chunk after release_chunk
chunk.access_chunk(block)
check_tensors()
# check function: reduce_chunk
norm = block.payload.float().norm(2)**2
chunk.reduce_chunk()
@ -87,9 +89,10 @@ def exam_chunk_states(nproc, group):
d = torch.randn(4, 32, device='cuda')
copy_d = d.clone()
private = [BlockRequire(1024, torch.float)]
mp = MemoryPool('cuda')
mp.allocate(private_block_list=private)
private_block_specs = [BlockSpec(1024, torch.float)]
mp.allocate_private_blocks(private_block_specs)
chunk = Chunk(mp, 1024, torch.float, group, rcache_fused=True)
assert chunk.chunk_size == 1024
@ -132,23 +135,16 @@ def exam_chunk_states(nproc, group):
print('chunk states are ok')
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
exam_chunk_functions(nproc=world_size, group=dist.GroupMember.WORLD)
exam_chunk_states(nproc=world_size, group=dist.GroupMember.WORLD)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_chunk_functions(world_size):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
spawn(run_dist, nprocs=world_size)
if __name__ == '__main__':
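
test_chunk.py above and the remaining test files below all switch to the same launcher pattern. Condensed from the diffs, the shared skeleton looks as follows; exam_something and its body are placeholders.

```python
import pytest
import torch.distributed as dist

import colossalai
from colossalai.testing import run_on_environment_flag, spawn


def exam_something(group):
    # the actual checks run against the given process group
    ...


def run_dist(rank, world_size, port):
    # colossalai.launch replaces the manual env-var setup and init_distributed()
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
    exam_something(group=dist.GroupMember.WORLD)


@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_something(world_size):
    # colossalai.testing.spawn launches world_size processes and fills in rank, world_size and port
    spawn(run_dist, nprocs=world_size)
```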

tests/test_elixir/test_chunk/test_fetcher.py (25 changed lines)

@ -1,6 +1,4 @@
import copy
import os
from functools import partial
import pytest
import torch
@ -8,9 +6,10 @@ import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.elixir.chunk import ChunkGroup
from colossalai.elixir.utils import init_distributed, seed_all
from colossalai.testing import run_on_environment_flag
from colossalai.elixir.utils import seed_all
from colossalai.testing import run_on_environment_flag, spawn
from tests.test_elixir.test_chunk.fetcher_utils import hook_transform
from tests.test_elixir.utils import TEST_MODELS, to_cuda
@ -25,7 +24,7 @@ def check_gradient(ddp_model, my_model, cg: ChunkGroup):
assert_close(p0.grad.data, p1.data)
def exam_chunk_fetcher(nproc, group):
def exam_chunk_fetcher(group):
model_fn, data_fn = TEST_MODELS.get('resnet')
torch_model = model_fn().cuda()
test_model = copy.deepcopy(torch_model)
@ -49,23 +48,17 @@ def exam_chunk_fetcher(nproc, group):
print('private chunk fetcher is ok')
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_chunk_fetcher(nproc=world_size, group=dist.GroupMember.WORLD)
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
exam_chunk_fetcher(group=dist.GroupMember.WORLD)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_chunk_fetcher(world_size):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
spawn(run_dist, nprocs=world_size)
if __name__ == '__main__':
test_chunk_fetcher(world_size=2)
test_chunk_fetcher(world_size=2)

tests/test_elixir/test_chunk/test_group.py (30 changed lines)

@ -1,16 +1,13 @@
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
from colossalai.elixir.chunk import BlockRequire, ChunkGroup, MemoryPool, TensorState
from colossalai.elixir.utils import init_distributed
from colossalai.testing import run_on_environment_flag
import colossalai
from colossalai.elixir.chunk import BlockSpec, ChunkGroup, MemoryPool, TensorState
from colossalai.testing import run_on_environment_flag, spawn
def exam_chunk_group_functions(nproc, group):
def exam_chunk_group_functions(group):
a = torch.randn(3, 64, device='cuda')
copy_a = a.clone()
b = torch.randn(2, 32, device='cuda')
@ -23,7 +20,9 @@ def exam_chunk_group_functions(nproc, group):
copy_e = e.clone()
mp = MemoryPool('cuda')
mp.allocate(public_block_size=256, public_block_number=2, private_block_list=[BlockRequire(68, torch.float)])
mp.allocate_public_blocks(block_num=2, block_spec=BlockSpec(numel=256, dtype=torch.float))
mp.allocate_private_blocks([BlockSpec(68, torch.float)])
cg = ChunkGroup(rcache=mp)
c0 = cg.allocate_chunk([a, b], 256, torch.float, group)
c1 = cg.allocate_chunk([c], 256, torch.float, group)
@ -76,22 +75,15 @@ def exam_chunk_group_functions(nproc, group):
print('chunk group functions are ok')
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_chunk_group_functions(nproc=world_size, group=dist.GroupMember.WORLD)
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
exam_chunk_group_functions(group=dist.GroupMember.WORLD)
@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 2, 4])
@run_on_environment_flag('ELX')
def test_chunk_group(world_size):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
spawn(run_dist, nprocs=world_size)
if __name__ == '__main__':

tests/test_elixir/test_chunk/test_scheduler.py (31 changed lines)

@ -1,19 +1,16 @@
import os
from functools import partial
import pytest
import torch
import torch.distributed as dist
import colossalai
from colossalai.elixir.chunk import Chunk, MemoryPool
from colossalai.elixir.chunk.scheduler import FIFOScheduler, PrefetchScheduler
from colossalai.elixir.utils import init_distributed
from colossalai.testing import run_on_environment_flag
from colossalai.testing import spawn
def exam_fifo(nproc, group):
def exam_fifo(group):
mp = MemoryPool('cuda')
mp.allocate(public_block_number=1)
mp.allocate_public_blocks(block_num=1)
c0 = Chunk(mp, 1024, torch.float, group)
c1 = Chunk(mp, 1024, torch.float, group)
c2 = Chunk(mp, 1024, torch.float, group)
@ -40,9 +37,8 @@ def exam_fifo(nproc, group):
assert sdl.top() == c0
def exam_prefetch(nproc, group):
def exam_prefetch(group):
mp = MemoryPool('cuda')
mp.allocate()
c0 = Chunk(mp, 1024, torch.float, group)
c1 = Chunk(mp, 1024, torch.float, group)
c2 = Chunk(mp, 1024, torch.float, group)
@ -108,22 +104,15 @@ def exam_prefetch(nproc, group):
sdl.clear()
def run_dist(rank, world_size):
os.environ['RANK'] = str(rank)
os.environ['LOCAL_RANK'] = str(rank)
os.environ['WORLD_SIZE'] = str(world_size)
os.environ['MASTER_ADDR'] = '127.0.0.1'
os.environ['MASTER_PORT'] = str(29512)
init_distributed()
exam_fifo(nproc=world_size, group=dist.GroupMember.WORLD)
exam_prefetch(nproc=world_size, group=dist.GroupMember.WORLD)
def run_dist(rank, world_size, port):
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
exam_fifo(group=dist.GroupMember.WORLD)
exam_prefetch(group=dist.GroupMember.WORLD)
@pytest.mark.dist
@run_on_environment_flag('ELX')
def test_chunk_scheduler(world_size=1):
run_func = partial(run_dist, world_size=world_size)
torch.multiprocessing.spawn(run_func, nprocs=world_size)
spawn(run_dist, nprocs=world_size)
if __name__ == '__main__':

tests/test_elixir/test_ctx/test_meta_ctx.py (4 changed lines)

@ -1,9 +1,7 @@
from colossalai.elixir.ctx import MetaContext
from colossalai.testing import run_on_environment_flag
from colossalai.elixir.context import MetaContext
from tests.test_elixir.utils import TEST_MODELS
@run_on_environment_flag('ELX')
def test_meta_context():
builder, *_ = TEST_MODELS.get('resnet')
with MetaContext():
