diff --git a/colossalai/elixir/chunk/__init__.py b/colossalai/elixir/chunk/__init__.py index 72d17dbc1..bf023576d 100644 --- a/colossalai/elixir/chunk/__init__.py +++ b/colossalai/elixir/chunk/__init__.py @@ -1,2 +1,7 @@ -from .core import BlockRequire, Chunk, ChunkGroup, MemoryPool, PrivateBlock, PublicBlock, TensorBlock, TensorState +from .core import BlockSpec, Chunk, ChunkGroup, MemoryPool, PrivateBlock, PublicBlock, TensorBlock, TensorState from .fetcher import ChunkFetcher + +__all__ = [ + 'BlockSpec', 'Chunk', 'ChunkGroup', 'MemoryPool', 'PrivateBlock', 'PublicBlock', 'TensorBlock', 'TensorState', + 'ChunkFetcher' +] diff --git a/colossalai/elixir/chunk/core/__init__.py b/colossalai/elixir/chunk/core/__init__.py index 468d5428e..221630306 100644 --- a/colossalai/elixir/chunk/core/__init__.py +++ b/colossalai/elixir/chunk/core/__init__.py @@ -1,4 +1,8 @@ from .chunk import Chunk from .group import ChunkGroup -from .memory_pool import BlockRequire, MemoryPool, PrivateBlock, PublicBlock, TensorBlock +from .memory_pool import BlockSpec, MemoryPool, PrivateBlock, PublicBlock, TensorBlock from .states import TensorState + +__all__ = [ + 'Chunk', 'ChunkGroup', 'BlockSpec', 'MemoryPool', 'PrivateBlock', 'PublicBlock', 'TensorBlock', 'TensorState' +] diff --git a/colossalai/elixir/chunk/core/chunk.py b/colossalai/elixir/chunk/core/chunk.py index df570d289..b88824f58 100644 --- a/colossalai/elixir/chunk/core/chunk.py +++ b/colossalai/elixir/chunk/core/chunk.py @@ -8,8 +8,8 @@ from torch.distributed import ProcessGroup from colossalai.elixir.cuda import gpu_device from colossalai.elixir.tensor import FakeTensor -from .memory_pool import MemoryPool, PrivateBlock, PublicBlock, TensorBlock -from .states import TensorState, ts_update_sanity_check +from .memory_pool import MemoryPool, TensorBlock +from .states import TensorState, validate_tensor_state_update class ChunkFullError(Exception): @@ -383,7 +383,11 @@ class Chunk: prev_state = self.tensors_info[tensor].state if prev_state == tensor_state: return - if ts_update_sanity_check(prev_state, tensor_state): + + # validate whether the update is legal + # if illegal, raise an exception + is_update_valid = validate_tensor_state_update(prev_state, tensor_state, raise_exception=True) + if is_update_valid: self.__update_one_tensor_info(self.tensors_info[tensor], tensor_state) def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None: diff --git a/colossalai/elixir/chunk/core/group.py b/colossalai/elixir/chunk/core/group.py index 495040e51..7fc373289 100644 --- a/colossalai/elixir/chunk/core/group.py +++ b/colossalai/elixir/chunk/core/group.py @@ -129,7 +129,7 @@ class ChunkGroup(object): """Check whether the rcache has enough blocks to store the gathered chunk.""" if chunk.rcache_fused: return True - return self.rcache.public_free_cnt > 0 + return self.rcache.public_free_count > 0 def access_chunk(self, chunk: Chunk) -> bool: """Access a chunk into rCache.""" @@ -141,7 +141,7 @@ class ChunkGroup(object): if chunk.rcache_fused: block = None else: - block = self.rcache.get_public_block() + block = self.rcache.pop_public_block() chunk.access_chunk(block) self.__add_to_accset(chunk) return True diff --git a/colossalai/elixir/chunk/core/memory_pool.py b/colossalai/elixir/chunk/core/memory_pool.py index e73fc65a6..e124dffee 100644 --- a/colossalai/elixir/chunk/core/memory_pool.py +++ b/colossalai/elixir/chunk/core/memory_pool.py @@ -1,35 +1,52 @@ from abc import ABC from collections import defaultdict +from enum import Enum from typing import Iterable, NamedTuple import torch from torch.autograd.profiler_util import _format_memory -class BlockRequire(NamedTuple): +class BlockSpec(NamedTuple): + """ + BlockSpec is the specification of a block. It contains the number of elements and the data type of the block. + + Args: + numel (int): the number of elements in the block + dtype (torch.dtype): the data type of the block + """ numel: int dtype: torch.dtype +class BlockType(Enum): + """ + BlockType is the type of a block. There are two types of blocks: public and private. + """ + PUBLIC = 0 + PRIVATE = 1 + + class TensorBlock(ABC): - """TensorBlock is the memory unit of memory pool. - It is a continuous memory block used to store tensors. + """ + TensorBlock is the memory unit of memory pool. It is a contiguous memory block used to store tensors. Each chunk needs a corresponding TensorBlock to store its data during training. args: - numel: the number of elements in the block - dtype: the data type of the block - device_type: the device type of the block + size (int): the number of elements in the block + dtype (torch.dtype): the data type of the block + device_type (str): the device type of the block """ total_count: int = 0 - def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None: + def __init__(self, size: int, dtype: torch.dtype, device_type: str, block_type: BlockType) -> None: self.block_id = TensorBlock.total_count TensorBlock.total_count += 1 self.device_type = device_type - self.payload: torch.Tensor = torch.empty((numel,), dtype=dtype, device=device_type) - self.memo_occ: int = self.payload.numel() * self.payload.element_size() + self.payload: torch.Tensor = torch.empty((size,), dtype=dtype, device=device_type) + self.size_in_bytes: int = self.payload.numel() * self.payload.element_size() + self.block_type = block_type @property def numel(self): @@ -50,122 +67,145 @@ class TensorBlock(ABC): return self.block_id == other.block_id def __repr__(self) -> str: - return f'(id={self.block_id}, numel={self.numel}, device={self.device_type}, dtype={self.dtype}, memo={self.memo_occ})' + return f'{self.block_type}(\n\tID = {self.block_id}, \n\tsize = {self.numel}, \n\tdevice = {self.device_type}, \n\tdtype = {self.dtype}, \n\tsize in bytes={self.size_in_bytes}\n)' class PublicBlock(TensorBlock): - """Public blocks have the same length. - Chunks of the same length can share the same public block. + """ + Public blocks have the same length. Chunks of the same length can share the same public block. """ def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None: - super().__init__(numel, dtype, device_type) - self.block_type = 'public' - - def __repr__(self) -> str: - return f'PublicBlock{super().__repr__()}' + super().__init__(numel, dtype, device_type, BlockType.PUBLIC) class PrivateBlock(TensorBlock): - """Private blocks may have different lengths. - Each private chunk should use its own private block. + """ + Private blocks may have different lengths. Each private chunk should use its own private block. """ def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None: - super().__init__(numel, dtype, device_type) - self.block_type = 'private' - - def __repr__(self) -> str: - return f'PrivateBlock{super().__repr__()}' + super().__init__(numel, dtype, device_type, BlockType.PRIVATE) class MemoryPool(object): - """A memory pool consists of public blocks and private blocks. + """ + A memory pool consists of public blocks and private blocks. rCache uses memory pool to manage memory bolcks. Users should allocate memory blocks before using it. args: - device_type: the device type of the memory pool + device_type (str): the device type of the memory pool """ def __init__(self, device_type: str) -> None: + assert device_type in [ + 'cuda', 'cpu' + ], f'Expected device type to be cuda or cpu, but got an invalid device type: {device_type}' self.device_type: str = device_type + # public space + # public space = number of public block x the public block size in bytes + # all public blocks have the same block size self.public_space: int = 0 self.public_block_size: int = 0 self.public_dtype: torch.dtype = None - self.public_free_blocks: list = None - self.public_used_blocks: set = None - - self.public_free_cnt: int = 0 - self.public_used_cnt: int = 0 + # create block holder and counter + self.public_free_blocks: list = list() + self.public_used_blocks: set = set() + self.public_free_count: int = 0 + self.public_used_count: int = 0 + # private space + # private look up dict returns an empty list if the block is not found self.private_space: int = 0 - self.private_blocks: list = None - self.private_lookup_dict: dict[BlockRequire, list] = None - - self.__allocate_flag = False - - def allocate(self, - public_dtype: torch.dtype = torch.float, - public_block_size: int = 1024, - public_block_number: int = 0, - private_block_list: Iterable[BlockRequire] = ()): - assert self.__allocate_flag is False - assert public_block_number >= 0 - - self.public_free_blocks = list() - self.public_used_blocks = set() - for _ in range(public_block_number): - block = PublicBlock(public_block_size, public_dtype, self.device_type) + self.private_blocks: list = list() + self.private_lookup_dict: dict[BlockSpec, list] = defaultdict(list) + + # flags for block allcation + self.__public_allocated_flag = False + self.__private_allocated_flag = False + + def allocate_public_blocks(self, block_num: int, block_spec: BlockSpec = None): + """ + Allocate public tensor blocks for the memory pool. This method will allocate public_block_number blocks with size equal to public_block_size. + """ + assert not self.__public_allocated_flag, 'Public blocks have been allocated to this MemoryPool object, it is not allowed to allocate again.' + assert block_num >= 0, f'Expected public_block_number >= 0, but got {block_num}' + + if block_spec is None: + block_spec = BlockSpec(numel=1024, dtype=torch.float) + + # allocate public blocks + for _ in range(block_num): + block = PublicBlock(block_spec.numel, block_spec.dtype, self.device_type) self.public_free_blocks.append(block) - - if public_block_number <= 0: - self.public_space = 0 - else: - self.public_space = self.public_free_blocks[0].memo_occ * public_block_number - self.public_block_size = public_block_size - self.public_dtype = public_dtype - - self.public_free_cnt = public_block_number - self.public_used_cnt = 0 - - self.private_space = 0 - self.private_blocks = list() - self.private_lookup_dict = defaultdict(list) - - for require in private_block_list: - block = PrivateBlock(require.numel, require.dtype, self.device_type) - self.private_space += block.memo_occ + self.public_space += block.size_in_bytes + self.public_free_count += 1 + + # store the block spec info + self.public_block_size = block_spec.numel + self.public_dtype = block_spec.dtype + + def allocate_private_blocks(self, block_specs: Iterable[BlockSpec]): + """ + Allocate private blocks for the memory pool. This method will allocate private blocks according to the block_specs. + + Args: + block_specs (Iterable[BlockSpec]): the block specs of the private blocks to be allocated + """ + # allocate private blocks + assert not self.__private_allocated_flag, 'Private blocks have been allocated to this MemoryPool object, it is not allowed to allocate again.' + + for spec in block_specs: + block = PrivateBlock(spec.numel, spec.dtype, self.device_type) + self.private_space += block.size_in_bytes self.private_blocks.append(block) - self.private_lookup_dict[require].append(block) + self.private_lookup_dict[spec].append(block) - self.__allocate_flag = True + self.__private_allocated_flag = True def __repr__(self) -> str: - return f'MP(public_space={_format_memory(self.public_space)}, private_space={_format_memory(self.private_space)})' + return f'Memory Pool(\n\tpublic_space = {_format_memory(self.public_space)}, \n\tprivate_space={_format_memory(self.private_space)}\n)' + + def get_private_block(self, numel: int, dtype: torch.dtype) -> PrivateBlock: + """ + Get a private block with the given numel and dtype. + """ + block_list = self.private_lookup_dict.get(BlockSpec(numel=numel, dtype=dtype)) - def get_private_block(self, numel: int, dtype: torch.dtype): - block_list = self.private_lookup_dict.get(BlockRequire(numel=numel, dtype=dtype)) - return block_list.pop() + if len(block_list) == 0: + raise ValueError(f'No private block with numel={numel} and dtype={dtype} is found.') + else: + return block_list.pop() - def get_public_block(self): - self.public_free_cnt -= 1 - self.public_used_cnt += 1 + def pop_public_block(self) -> PublicBlock: + """ + Get a public block from the memory pool. + """ + self.public_free_count -= 1 + self.public_used_count += 1 block = self.public_free_blocks.pop() self.public_used_blocks.add(block) - return block - def free_public_block(self, block: TensorBlock): + def free_public_block(self, block: TensorBlock) -> PublicBlock: + """ + Free a public block to the memory pool. + + Args: + block (TensorBlock): the public block to be freed + """ assert isinstance(block, PublicBlock) - assert block in self.public_used_blocks + assert block in self.public_used_blocks, f'Cound not find the given block in the used public blocks' - self.public_free_cnt += 1 - self.public_used_cnt -= 1 + # update counter + self.public_free_count += 1 + self.public_used_count -= 1 + # update free and used blocks self.public_used_blocks.remove(block) self.public_free_blocks.append(block) diff --git a/colossalai/elixir/chunk/core/states.py b/colossalai/elixir/chunk/core/states.py index 90d4c9260..721221592 100644 --- a/colossalai/elixir/chunk/core/states.py +++ b/colossalai/elixir/chunk/core/states.py @@ -2,6 +2,10 @@ from enum import Enum class TensorState(Enum): + """ + TensorState represents the state of a tensor in Elixir. + There are five states of a tensor: free, compute, hold, hold_after_bwd, ready_for_reduce. + """ FREE = 0 COMPUTE = 1 HOLD = 2 @@ -9,17 +13,35 @@ class TensorState(Enum): READY_FOR_REDUCE = 4 -# expected: free -> hold -> compute -> hold -> +# this includes the possible state transition in tensor state: +# the item in the list is in the format of (old_state, new_state) +# the complete state transtition is: +# free -> hold -> compute -> hold -> # -> compute -> hold_after_bwd -> ready_for_reduce -legal_ts_update_list = [(TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE), - (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), - (TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), - (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), - (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), - (TensorState.READY_FOR_REDUCE, TensorState.HOLD)] +LEGAL_TENSOR_STATE_UPDATE_LIST = [(TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE), + (TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE), + (TensorState.COMPUTE, TensorState.HOLD), + (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD), + (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE), + (TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), + (TensorState.READY_FOR_REDUCE, TensorState.HOLD)] -def ts_update_sanity_check(old_state, new_state) -> bool: - if (old_state, new_state) not in legal_ts_update_list: - raise RuntimeError(f'illegal tensor state updating: {old_state} -> {new_state}') +def validate_tensor_state_update(old_state: TensorState, new_state: TensorState, raise_exception: bool = False) -> bool: + """ + Validate the tensor state update is legal or not. + + Args: + old_state (TensorState): the old state of the tensor + new_state (TensorState): the new state of the tensor + raise_exception (bool, optional): whether to raise exception when the state update is illegal. Defaults to False. + + Returns: + bool: whether the state update is legal or not. + """ + if (old_state, new_state) not in LEGAL_TENSOR_STATE_UPDATE_LIST: + if raise_exception: + raise RuntimeError(f'Found illegal tensor state updating: {old_state} -> {new_state}') + else: + return False return True diff --git a/colossalai/elixir/context/__init__.py b/colossalai/elixir/context/__init__.py new file mode 100644 index 000000000..2cbcbe480 --- /dev/null +++ b/colossalai/elixir/context/__init__.py @@ -0,0 +1,3 @@ +from .meta_context import MetaContext + +__all__ = ['MetaContext'] diff --git a/colossalai/elixir/ctx/__init__.py b/colossalai/elixir/context/meta_context.py similarity index 61% rename from colossalai/elixir/ctx/__init__.py rename to colossalai/elixir/context/meta_context.py index 6c56ea0c5..ee0535ac8 100644 --- a/colossalai/elixir/ctx/__init__.py +++ b/colossalai/elixir/context/meta_context.py @@ -1,6 +1,6 @@ import torch -tensor_creation_methods = dict(tensor=torch.tensor, +TESNOR_CREATION_METHODS = dict(tensor=torch.tensor, sparse_coo_tensor=torch.sparse_coo_tensor, asarray=torch.asarray, as_tensor=torch.as_tensor, @@ -29,4 +29,34 @@ tensor_creation_methods = dict(tensor=torch.tensor, polar=torch.polar, heaviside=torch.heaviside) -from .meta_ctx import MetaContext + +# TODO: unify this with lazy init context +class MetaContext(object): + """A context manager that wraps all tensor creation methods in torch. + By default, all tensors will be created in meta. + + args: + device_type: The device type of the tensors to be created. + """ + + def __init__(self, device_type: str = 'meta') -> None: + super().__init__() + self.device_type = device_type + return None + + def __enter__(self): + + def meta_wrap(func): + + def wrapped_func(*args, **kwargs): + kwargs['device'] = self.device_type + return func(*args, **kwargs) + + return wrapped_func + + for name, method in TESNOR_CREATION_METHODS.items(): + setattr(torch, name, meta_wrap(method)) + + def __exit__(self, exc_type, exc_val, exc_tb): + for name, method in TESNOR_CREATION_METHODS.items(): + setattr(torch, name, method) diff --git a/colossalai/elixir/ctx/meta_ctx.py b/colossalai/elixir/ctx/meta_ctx.py deleted file mode 100644 index 7710a5971..000000000 --- a/colossalai/elixir/ctx/meta_ctx.py +++ /dev/null @@ -1,34 +0,0 @@ -import torch - -from colossalai.elixir.ctx import tensor_creation_methods - - -class MetaContext(object): - """A context manager that wraps all tensor creation methods in torch. - By default, all tensors will be created in meta. - - args: - device_type: The device type of the tensors to be created. - """ - - def __init__(self, device_type: str = 'meta') -> None: - super().__init__() - self.device_type = device_type - return None - - def __enter__(self): - - def meta_wrap(func): - - def wrapped_func(*args, **kwargs): - kwargs['device'] = self.device_type - return func(*args, **kwargs) - - return wrapped_func - - for name, method in tensor_creation_methods.items(): - setattr(torch, name, meta_wrap(method)) - - def __exit__(self, exc_type, exc_val, exc_tb): - for name, method in tensor_creation_methods.items(): - setattr(torch, name, method) diff --git a/colossalai/elixir/kernels/__init__.py b/colossalai/elixir/kernels/__init__.py index 3ca4a2614..1e390e7e2 100644 --- a/colossalai/elixir/kernels/__init__.py +++ b/colossalai/elixir/kernels/__init__.py @@ -1,4 +1,3 @@ -import torch import torch.nn.functional as F fused_torch_functions = {F.layer_norm: F.layer_norm} @@ -12,6 +11,3 @@ def register_fused_layer_norm(): except: print('Cannot import fused layer norm, please install apex from source.') pass - - -register_fused_layer_norm() diff --git a/colossalai/elixir/search/base.py b/colossalai/elixir/search/base.py index 50baa27fa..a71fcf9ec 100644 --- a/colossalai/elixir/search/base.py +++ b/colossalai/elixir/search/base.py @@ -5,7 +5,7 @@ from typing import List, Tuple import torch import torch.nn as nn -from colossalai.elixir.chunk import BlockRequire, ChunkGroup, MemoryPool +from colossalai.elixir.chunk import BlockSpec, ChunkGroup, MemoryPool from colossalai.elixir.tracer.param_tracer import generate_tf_order from colossalai.elixir.tracer.utils import meta_copy from colossalai.elixir.utils import print_rank_0 @@ -119,7 +119,7 @@ class SearchBase(ABC): for plan in chunk_plans: kwargs = plan.kwargs if kwargs.get('rcache_fused', False): - block_require_list.append(BlockRequire(plan.chunk_size, plan.chunk_dtype)) + block_require_list.append(BlockSpec(plan.chunk_size, plan.chunk_dtype)) mp = MemoryPool('cuda') mp.allocate(public_dtype=self.unified_dtype, diff --git a/tests/test_elixir/test_chunk/fetcher_utils.py b/tests/test_elixir/test_chunk/fetcher_utils.py index 22caedee6..e81165e66 100644 --- a/tests/test_elixir/test_chunk/fetcher_utils.py +++ b/tests/test_elixir/test_chunk/fetcher_utils.py @@ -4,7 +4,7 @@ import torch import torch.distributed as dist import torch.nn as nn -from colossalai.elixir.chunk import BlockRequire, ChunkFetcher, ChunkGroup, MemoryPool, TensorState +from colossalai.elixir.chunk import BlockSpec, ChunkFetcher, ChunkGroup, MemoryPool, TensorState from colossalai.elixir.chunk.scheduler import FIFOScheduler from colossalai.elixir.hook import BufferStore, HookParam from colossalai.elixir.tensor import OutplaceTensor @@ -32,13 +32,15 @@ def grad_handler(grad: torch.Tensor, param: nn.Parameter, fetcher: ChunkFetcher) def hook_transform(model: nn.Module, process_group: dist.ProcessGroupGloo): pg_size = dist.get_world_size(process_group) - private_list = list() + mp = MemoryPool('cuda') + + # allocate private blocks + private_block_specs = list() for param in model.parameters(): block_size = to_divide(param.numel(), pg_size) - private_list.append(BlockRequire(block_size, param.dtype)) + private_block_specs.append(BlockSpec(block_size, param.dtype)) + mp.allocate_private_blocks(private_block_specs) - mp = MemoryPool('cuda') - mp.allocate(private_block_list=private_list) cg = ChunkGroup(rcache=mp) # allocate chunk group fused_config = dict(rcache_fused=True) diff --git a/tests/test_elixir/test_chunk/test_block.py b/tests/test_elixir/test_chunk/test_block.py index bec8bf42b..b169e86cc 100644 --- a/tests/test_elixir/test_chunk/test_block.py +++ b/tests/test_elixir/test_chunk/test_block.py @@ -1,60 +1,62 @@ import torch -from colossalai.elixir.chunk import BlockRequire, MemoryPool, PrivateBlock, PublicBlock +from colossalai.elixir.chunk import BlockSpec, MemoryPool, PrivateBlock, PublicBlock from colossalai.testing import run_on_environment_flag -@run_on_environment_flag('ELX') def test_block(): - b = PublicBlock(123, torch.float16, 'cuda') - payload_b = b.payload - - assert payload_b.numel() == 123 - assert payload_b.dtype == torch.float16 - assert payload_b.device.type == 'cuda' - assert payload_b.numel() * payload_b.element_size() == b.memo_occ - - c = PrivateBlock(77, torch.float, 'cpu') - payload_c = c.payload - - assert payload_c.numel() == 77 - assert payload_c.dtype == torch.float - assert payload_c.device.type == 'cpu' - assert payload_c.numel() * payload_c.element_size() == c.memo_occ - + # test for public block + public_block = PublicBlock(123, torch.float16, 'cuda') + public_payload = public_block.payload + + assert public_payload.numel() == 123 + assert public_payload.dtype == torch.float16 + assert public_payload.device.type == 'cuda' + assert public_payload.numel() * public_payload.element_size() == public_block.size_in_bytes + + # test for private block + private_block = PrivateBlock(77, torch.float, 'cpu') + private_payload = private_block.payload + + assert private_payload.numel() == 77 + assert private_payload.dtype == torch.float + assert private_payload.device.type == 'cpu' + assert private_payload.numel() * private_payload.element_size() == private_block.size_in_bytes print('test_block: ok') -@run_on_environment_flag('ELX') def test_memory_pool(): mp = MemoryPool(device_type='cuda') - private_list = [BlockRequire(5, torch.float), BlockRequire(81, torch.float16)] - mp.allocate(public_block_number=4, private_block_list=private_list) - block0 = mp.get_public_block() + # allocate public blocks + mp.allocate_public_blocks(block_num=4) - assert block0 in mp.public_used_blocks - assert mp.public_used_cnt == 1 - assert mp.public_free_cnt == 3 + # allocate private blocks + private_block_specs = [BlockSpec(5, torch.float), BlockSpec(81, torch.float16)] + mp.allocate_private_blocks(private_block_specs) - block1 = mp.get_public_block() + # test for public blocks + block0 = mp.pop_public_block() + assert block0 in mp.public_used_blocks + assert mp.public_used_count == 1 + assert mp.public_free_count == 3 + block1 = mp.pop_public_block() assert block1 in mp.public_used_blocks - assert mp.public_used_cnt == 2 - assert mp.public_free_cnt == 2 + assert mp.public_used_count == 2 + assert mp.public_free_count == 2 mp.free_public_block(block0) mp.free_public_block(block1) - assert block0 in mp.public_free_blocks assert block1 in mp.public_free_blocks - assert mp.public_used_cnt == 0 - assert mp.public_free_cnt == 4 + assert mp.public_used_count == 0 + assert mp.public_free_count == 4 + # test for private block block0 = mp.get_private_block(5, torch.float) assert block0.numel == 5 assert block0.dtype == torch.float - print('test_memory_pool: ok') diff --git a/tests/test_elixir/test_chunk/test_chunk.py b/tests/test_elixir/test_chunk/test_chunk.py index f7fb9a0dd..0dfb140f9 100644 --- a/tests/test_elixir/test_chunk/test_chunk.py +++ b/tests/test_elixir/test_chunk/test_chunk.py @@ -1,13 +1,10 @@ -import os -from functools import partial - import pytest import torch import torch.distributed as dist -from colossalai.elixir.chunk import BlockRequire, Chunk, MemoryPool, TensorState -from colossalai.elixir.utils import init_distributed -from colossalai.testing import run_on_environment_flag +import colossalai +from colossalai.elixir.chunk import BlockSpec, Chunk, MemoryPool, TensorState +from colossalai.testing import run_on_environment_flag, spawn def exam_chunk_functions(nproc, group): @@ -21,7 +18,7 @@ def exam_chunk_functions(nproc, group): copy_d = d.clone() mp = MemoryPool('cuda') - mp.allocate(public_block_number=1) + mp.allocate_public_blocks(block_num=1) chunk = Chunk(mp, 1024, torch.float, group) chunk.l2_norm_flag = True @@ -43,26 +40,31 @@ def exam_chunk_functions(nproc, group): chunk.close_chunk() assert chunk.is_replica is False + # check function: get_cpu_copy cpu_copys = chunk.get_cpu_copy() for t_gpu, t_cpu in zip([copy_a, copy_b, copy_c, copy_d], cpu_copys): assert t_cpu.device.type == 'cpu' assert torch.equal(t_gpu.cpu(), t_cpu) + # check function: access_chunk - block = mp.get_public_block() + block = mp.pop_public_block() chunk.access_chunk(block) assert chunk.is_replica assert chunk.scatter_check check_tensors() + # check function: release_chunk chunk.optim_sync_flag = False block = chunk.release_chunk() assert block in mp.public_used_blocks assert chunk.is_replica is False assert chunk.optim_sync_flag is True + # check function: access_chunk after release_chunk chunk.access_chunk(block) check_tensors() + # check function: reduce_chunk norm = block.payload.float().norm(2)**2 chunk.reduce_chunk() @@ -87,9 +89,10 @@ def exam_chunk_states(nproc, group): d = torch.randn(4, 32, device='cuda') copy_d = d.clone() - private = [BlockRequire(1024, torch.float)] mp = MemoryPool('cuda') - mp.allocate(private_block_list=private) + + private_block_specs = [BlockSpec(1024, torch.float)] + mp.allocate_private_blocks(private_block_specs) chunk = Chunk(mp, 1024, torch.float, group, rcache_fused=True) assert chunk.chunk_size == 1024 @@ -132,23 +135,16 @@ def exam_chunk_states(nproc, group): print('chunk states are ok') -def run_dist(rank, world_size): - os.environ['RANK'] = str(rank) - os.environ['LOCAL_RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = str(29512) - init_distributed() +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') exam_chunk_functions(nproc=world_size, group=dist.GroupMember.WORLD) exam_chunk_states(nproc=world_size, group=dist.GroupMember.WORLD) @pytest.mark.dist @pytest.mark.parametrize('world_size', [1, 2, 4]) -@run_on_environment_flag('ELX') def test_chunk_functions(world_size): - run_func = partial(run_dist, world_size=world_size) - torch.multiprocessing.spawn(run_func, nprocs=world_size) + spawn(run_dist, nprocs=world_size) if __name__ == '__main__': diff --git a/tests/test_elixir/test_chunk/test_fetcher.py b/tests/test_elixir/test_chunk/test_fetcher.py index 27906b18d..a3f070291 100644 --- a/tests/test_elixir/test_chunk/test_fetcher.py +++ b/tests/test_elixir/test_chunk/test_fetcher.py @@ -1,6 +1,4 @@ import copy -import os -from functools import partial import pytest import torch @@ -8,9 +6,10 @@ import torch.distributed as dist from torch.nn.parallel import DistributedDataParallel as DDP from torch.testing import assert_close +import colossalai from colossalai.elixir.chunk import ChunkGroup -from colossalai.elixir.utils import init_distributed, seed_all -from colossalai.testing import run_on_environment_flag +from colossalai.elixir.utils import seed_all +from colossalai.testing import run_on_environment_flag, spawn from tests.test_elixir.test_chunk.fetcher_utils import hook_transform from tests.test_elixir.utils import TEST_MODELS, to_cuda @@ -25,7 +24,7 @@ def check_gradient(ddp_model, my_model, cg: ChunkGroup): assert_close(p0.grad.data, p1.data) -def exam_chunk_fetcher(nproc, group): +def exam_chunk_fetcher(group): model_fn, data_fn = TEST_MODELS.get('resnet') torch_model = model_fn().cuda() test_model = copy.deepcopy(torch_model) @@ -49,23 +48,17 @@ def exam_chunk_fetcher(nproc, group): print('private chunk fetcher is ok') -def run_dist(rank, world_size): - os.environ['RANK'] = str(rank) - os.environ['LOCAL_RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = str(29512) - init_distributed() - exam_chunk_fetcher(nproc=world_size, group=dist.GroupMember.WORLD) +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + exam_chunk_fetcher(group=dist.GroupMember.WORLD) @pytest.mark.dist @pytest.mark.parametrize('world_size', [1, 2, 4]) -@run_on_environment_flag('ELX') def test_chunk_fetcher(world_size): - run_func = partial(run_dist, world_size=world_size) - torch.multiprocessing.spawn(run_func, nprocs=world_size) + spawn(run_dist, nprocs=world_size) if __name__ == '__main__': test_chunk_fetcher(world_size=2) + test_chunk_fetcher(world_size=2) diff --git a/tests/test_elixir/test_chunk/test_group.py b/tests/test_elixir/test_chunk/test_group.py index df183a9aa..2a8a4d662 100644 --- a/tests/test_elixir/test_chunk/test_group.py +++ b/tests/test_elixir/test_chunk/test_group.py @@ -1,16 +1,13 @@ -import os -from functools import partial - import pytest import torch import torch.distributed as dist -from colossalai.elixir.chunk import BlockRequire, ChunkGroup, MemoryPool, TensorState -from colossalai.elixir.utils import init_distributed -from colossalai.testing import run_on_environment_flag +import colossalai +from colossalai.elixir.chunk import BlockSpec, ChunkGroup, MemoryPool, TensorState +from colossalai.testing import run_on_environment_flag, spawn -def exam_chunk_group_functions(nproc, group): +def exam_chunk_group_functions(group): a = torch.randn(3, 64, device='cuda') copy_a = a.clone() b = torch.randn(2, 32, device='cuda') @@ -23,7 +20,9 @@ def exam_chunk_group_functions(nproc, group): copy_e = e.clone() mp = MemoryPool('cuda') - mp.allocate(public_block_size=256, public_block_number=2, private_block_list=[BlockRequire(68, torch.float)]) + mp.allocate_public_blocks(block_num=2, block_spec=BlockSpec(numel=256, dtype=torch.float)) + mp.allocate_private_blocks([BlockSpec(68, torch.float)]) + cg = ChunkGroup(rcache=mp) c0 = cg.allocate_chunk([a, b], 256, torch.float, group) c1 = cg.allocate_chunk([c], 256, torch.float, group) @@ -76,22 +75,15 @@ def exam_chunk_group_functions(nproc, group): print('chunk group functions are ok') -def run_dist(rank, world_size): - os.environ['RANK'] = str(rank) - os.environ['LOCAL_RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = str(29512) - init_distributed() - exam_chunk_group_functions(nproc=world_size, group=dist.GroupMember.WORLD) +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + exam_chunk_group_functions(group=dist.GroupMember.WORLD) @pytest.mark.dist @pytest.mark.parametrize('world_size', [1, 2, 4]) -@run_on_environment_flag('ELX') def test_chunk_group(world_size): - run_func = partial(run_dist, world_size=world_size) - torch.multiprocessing.spawn(run_func, nprocs=world_size) + spawn(run_dist, nprocs=world_size) if __name__ == '__main__': diff --git a/tests/test_elixir/test_chunk/test_scheduler.py b/tests/test_elixir/test_chunk/test_scheduler.py index d0e5a0f47..f95cfacd7 100644 --- a/tests/test_elixir/test_chunk/test_scheduler.py +++ b/tests/test_elixir/test_chunk/test_scheduler.py @@ -1,19 +1,16 @@ -import os -from functools import partial - import pytest import torch import torch.distributed as dist +import colossalai from colossalai.elixir.chunk import Chunk, MemoryPool from colossalai.elixir.chunk.scheduler import FIFOScheduler, PrefetchScheduler -from colossalai.elixir.utils import init_distributed -from colossalai.testing import run_on_environment_flag +from colossalai.testing import spawn -def exam_fifo(nproc, group): +def exam_fifo(group): mp = MemoryPool('cuda') - mp.allocate(public_block_number=1) + mp.allocate_public_blocks(block_num=1) c0 = Chunk(mp, 1024, torch.float, group) c1 = Chunk(mp, 1024, torch.float, group) c2 = Chunk(mp, 1024, torch.float, group) @@ -40,9 +37,8 @@ def exam_fifo(nproc, group): assert sdl.top() == c0 -def exam_prefetch(nproc, group): +def exam_prefetch(group): mp = MemoryPool('cuda') - mp.allocate() c0 = Chunk(mp, 1024, torch.float, group) c1 = Chunk(mp, 1024, torch.float, group) c2 = Chunk(mp, 1024, torch.float, group) @@ -108,22 +104,15 @@ def exam_prefetch(nproc, group): sdl.clear() -def run_dist(rank, world_size): - os.environ['RANK'] = str(rank) - os.environ['LOCAL_RANK'] = str(rank) - os.environ['WORLD_SIZE'] = str(world_size) - os.environ['MASTER_ADDR'] = '127.0.0.1' - os.environ['MASTER_PORT'] = str(29512) - init_distributed() - exam_fifo(nproc=world_size, group=dist.GroupMember.WORLD) - exam_prefetch(nproc=world_size, group=dist.GroupMember.WORLD) +def run_dist(rank, world_size, port): + colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') + exam_fifo(group=dist.GroupMember.WORLD) + exam_prefetch(group=dist.GroupMember.WORLD) @pytest.mark.dist -@run_on_environment_flag('ELX') def test_chunk_scheduler(world_size=1): - run_func = partial(run_dist, world_size=world_size) - torch.multiprocessing.spawn(run_func, nprocs=world_size) + spawn(run_dist, nprocs=world_size) if __name__ == '__main__': diff --git a/tests/test_elixir/test_ctx/test_meta_ctx.py b/tests/test_elixir/test_ctx/test_meta_ctx.py index 99d4ab1ec..3e7343b5d 100644 --- a/tests/test_elixir/test_ctx/test_meta_ctx.py +++ b/tests/test_elixir/test_ctx/test_meta_ctx.py @@ -1,9 +1,7 @@ -from colossalai.elixir.ctx import MetaContext -from colossalai.testing import run_on_environment_flag +from colossalai.elixir.context import MetaContext from tests.test_elixir.utils import TEST_MODELS -@run_on_environment_flag('ELX') def test_meta_context(): builder, *_ = TEST_MODELS.get('resnet') with MetaContext():