@@ -1,35 +1,52 @@
 from abc import ABC
 from collections import defaultdict
 from enum import Enum
 from typing import Iterable, NamedTuple

 import torch
 from torch.autograd.profiler_util import _format_memory


-class BlockRequire(NamedTuple):
+class BlockSpec(NamedTuple):
+    """
+    BlockSpec is the specification of a block. It contains the number of elements and the data type of the block.
+
+    Args:
+        numel (int): the number of elements in the block
+        dtype (torch.dtype): the data type of the block
+    """
     numel: int
     dtype: torch.dtype


+class BlockType(Enum):
+    """
+    BlockType is the type of a block. There are two types of blocks: public and private.
+    """
+    PUBLIC = 0
+    PRIVATE = 1
+
+
 class TensorBlock(ABC):
-    """TensorBlock is the memory unit of memory pool.
-    It is a continuous memory block used to store tensors.
+    """
+    TensorBlock is the memory unit of the memory pool. It is a contiguous memory block used to store tensors.
     Each chunk needs a corresponding TensorBlock to store its data during training.

     args:
-        numel: the number of elements in the block
-        dtype: the data type of the block
-        device_type: the device type of the block
+        size (int): the number of elements in the block
+        dtype (torch.dtype): the data type of the block
+        device_type (str): the device type of the block
     """
     total_count: int = 0

-    def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
+    def __init__(self, size: int, dtype: torch.dtype, device_type: str, block_type: BlockType) -> None:
         self.block_id = TensorBlock.total_count
         TensorBlock.total_count += 1

         self.device_type = device_type
-        self.payload: torch.Tensor = torch.empty((numel,), dtype=dtype, device=device_type)
-        self.memo_occ: int = self.payload.numel() * self.payload.element_size()
+        self.payload: torch.Tensor = torch.empty((size,), dtype=dtype, device=device_type)
+        self.size_in_bytes: int = self.payload.numel() * self.payload.element_size()
+        self.block_type = block_type

     @property
     def numel(self):
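For reference, a minimal sketch of how the pieces above fit together: BlockSpec describes a block, BlockType tags it as public or private, and TensorBlock holds the payload and its size in bytes. Illustrative only; PublicBlock is defined in the next hunk, and numel is assumed to return the payload's element count.

    import torch

    spec = BlockSpec(numel=1024, dtype=torch.float)     # a NamedTuple, hashable, so it can key the private lookup dict
    block = PublicBlock(spec.numel, spec.dtype, 'cpu')  # forwards BlockType.PUBLIC to TensorBlock.__init__
    assert block.numel == 1024                          # element count of the payload tensor
    assert block.size_in_bytes == 4096                  # 1024 elements x 4 bytes per torch.float element
    print(block)                                        # repr shows the BlockType, device, dtype, and size in bytes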
@@ -50,122 +67,145 @@ class TensorBlock(ABC):
         return self.block_id == other.block_id

     def __repr__(self) -> str:
-        return f'(id={self.block_id}, numel={self.numel}, device={self.device_type}, dtype={self.dtype}, memo={self.memo_occ})'
+        return f'{self.block_type}(\n\tID = {self.block_id}, \n\tsize = {self.numel}, \n\tdevice = {self.device_type}, \n\tdtype = {self.dtype}, \n\tsize in bytes = {self.size_in_bytes}\n)'


 class PublicBlock(TensorBlock):
-    """Public blocks have the same length.
-    Chunks of the same length can share the same public block.
+    """
+    Public blocks have the same length. Chunks of the same length can share the same public block.
     """

     def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
-        super().__init__(numel, dtype, device_type)
-        self.block_type = 'public'
-
-    def __repr__(self) -> str:
-        return f'PublicBlock{super().__repr__()}'
+        super().__init__(numel, dtype, device_type, BlockType.PUBLIC)


 class PrivateBlock(TensorBlock):
-    """Private blocks may have different lengths.
-    Each private chunk should use its own private block.
+    """
+    Private blocks may have different lengths. Each private chunk should use its own private block.
     """

     def __init__(self, numel: int, dtype: torch.dtype, device_type: str) -> None:
-        super().__init__(numel, dtype, device_type)
-        self.block_type = 'private'
-
-    def __repr__(self) -> str:
-        return f'PrivateBlock{super().__repr__()}'
+        super().__init__(numel, dtype, device_type, BlockType.PRIVATE)


 class MemoryPool(object):
-    """A memory pool consists of public blocks and private blocks.
+    """
+    A memory pool consists of public blocks and private blocks.
     rCache uses the memory pool to manage memory blocks.
     Users should allocate memory blocks before using them.

     args:
-        device_type: the device type of the memory pool
+        device_type (str): the device type of the memory pool
     """

     def __init__(self, device_type: str) -> None:
+        assert device_type in [
+            'cuda', 'cpu'
+        ], f'Expected device type to be cuda or cpu, but got an invalid device type: {device_type}'
         self.device_type: str = device_type

-        # public space
+        # public space = number of public blocks x the public block size in bytes
+        # all public blocks have the same block size
         self.public_space: int = 0
         self.public_block_size: int = 0
         self.public_dtype: torch.dtype = None

-        self.public_free_blocks: list = None
-        self.public_used_blocks: set = None
-        self.public_free_cnt: int = 0
-        self.public_used_cnt: int = 0
+        # create block holders and counters
+        self.public_free_blocks: list = list()
+        self.public_used_blocks: set = set()
+        self.public_free_count: int = 0
+        self.public_used_count: int = 0

-        # private space
+        # the private lookup dict returns an empty list if a block spec is not found
         self.private_space: int = 0
-        self.private_blocks: list = None
-        self.private_lookup_dict: dict[BlockRequire, list] = None
+        self.private_blocks: list = list()
+        self.private_lookup_dict: dict[BlockSpec, list] = defaultdict(list)

-        self.__allocate_flag = False
+        # flags for block allocation
+        self.__public_allocated_flag = False
+        self.__private_allocated_flag = False

-    def allocate(self,
-                 public_dtype: torch.dtype = torch.float,
-                 public_block_size: int = 1024,
-                 public_block_number: int = 0,
-                 private_block_list: Iterable[BlockRequire] = ()):
-        assert self.__allocate_flag is False
-        assert public_block_number >= 0
-
-        self.public_free_blocks = list()
-        self.public_used_blocks = set()
-        for _ in range(public_block_number):
-            block = PublicBlock(public_block_size, public_dtype, self.device_type)
-            self.public_free_blocks.append(block)
-
-        if public_block_number <= 0:
-            self.public_space = 0
-        else:
-            self.public_space = self.public_free_blocks[0].memo_occ * public_block_number
-        self.public_block_size = public_block_size
-        self.public_dtype = public_dtype
-
-        self.public_free_cnt = public_block_number
-        self.public_used_cnt = 0
-
-        self.private_space = 0
-        self.private_blocks = list()
-        self.private_lookup_dict = defaultdict(list)
-
-        for require in private_block_list:
-            block = PrivateBlock(require.numel, require.dtype, self.device_type)
-            self.private_space += block.memo_occ
-            self.private_blocks.append(block)
-            self.private_lookup_dict[require].append(block)
-
-        self.__allocate_flag = True
+    def allocate_public_blocks(self, block_num: int, block_spec: BlockSpec = None):
+        """
+        Allocate public tensor blocks for the memory pool. This method allocates block_num public blocks, each with the numel and dtype given by block_spec.
+        """
+        assert not self.__public_allocated_flag, 'Public blocks have been allocated to this MemoryPool object, it is not allowed to allocate again.'
+        assert block_num >= 0, f'Expected block_num >= 0, but got {block_num}'
+
+        if block_spec is None:
+            block_spec = BlockSpec(numel=1024, dtype=torch.float)
+
+        # allocate public blocks
+        for _ in range(block_num):
+            block = PublicBlock(block_spec.numel, block_spec.dtype, self.device_type)
+            self.public_free_blocks.append(block)
+            self.public_space += block.size_in_bytes
+            self.public_free_count += 1
+
+        # store the block spec info
+        self.public_block_size = block_spec.numel
+        self.public_dtype = block_spec.dtype
+
+        self.__public_allocated_flag = True    # mark public allocation done so the assert above can fire
+
+    def allocate_private_blocks(self, block_specs: Iterable[BlockSpec]):
+        """
+        Allocate private blocks for the memory pool. This method allocates private blocks according to block_specs.
+
+        Args:
+            block_specs (Iterable[BlockSpec]): the block specs of the private blocks to be allocated
+        """
+        assert not self.__private_allocated_flag, 'Private blocks have been allocated to this MemoryPool object, it is not allowed to allocate again.'
+
+        # allocate private blocks
+        for spec in block_specs:
+            block = PrivateBlock(spec.numel, spec.dtype, self.device_type)
+            self.private_space += block.size_in_bytes
+            self.private_blocks.append(block)
+            self.private_lookup_dict[spec].append(block)
+
+        self.__private_allocated_flag = True

     def __repr__(self) -> str:
-        return f'MP(public_space={_format_memory(self.public_space)}, private_space={_format_memory(self.private_space)})'
+        return f'Memory Pool(\n\tpublic_space = {_format_memory(self.public_space)}, \n\tprivate_space = {_format_memory(self.private_space)}\n)'

-    def get_private_block(self, numel: int, dtype: torch.dtype):
-        block_list = self.private_lookup_dict.get(BlockRequire(numel=numel, dtype=dtype))
-        return block_list.pop()
+    def get_private_block(self, numel: int, dtype: torch.dtype) -> PrivateBlock:
+        """
+        Get a private block with the given numel and dtype.
+        """
+        # index the defaultdict so a missing spec yields an empty list instead of None
+        block_list = self.private_lookup_dict[BlockSpec(numel=numel, dtype=dtype)]
+        if len(block_list) == 0:
+            raise ValueError(f'No private block with numel={numel} and dtype={dtype} is found.')
+        else:
+            return block_list.pop()

-    def get_public_block(self):
-        self.public_free_cnt -= 1
-        self.public_used_cnt += 1
+    def pop_public_block(self) -> PublicBlock:
+        """
+        Get a public block from the memory pool.
+        """
+        self.public_free_count -= 1
+        self.public_used_count += 1
         block = self.public_free_blocks.pop()
         self.public_used_blocks.add(block)
         return block

-    def free_public_block(self, block: TensorBlock):
+    def free_public_block(self, block: TensorBlock) -> None:
+        """
+        Free a public block back to the memory pool.
+
+        Args:
+            block (TensorBlock): the public block to be freed
+        """
         assert isinstance(block, PublicBlock)
-        assert block in self.public_used_blocks
+        assert block in self.public_used_blocks, 'Could not find the given block in the used public blocks'

-        self.public_free_cnt += 1
-        self.public_used_cnt -= 1
+        # update counters
+        self.public_free_count += 1
+        self.public_used_count -= 1

+        # update free and used blocks
         self.public_used_blocks.remove(block)
         self.public_free_blocks.append(block)
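
For reference, a minimal end-to-end sketch of the reworked MemoryPool API (illustrative only; assumes a CPU pool, pass 'cuda' for a GPU pool):

    import torch

    pool = MemoryPool('cpu')
    pool.allocate_public_blocks(block_num=2, block_spec=BlockSpec(numel=1024, dtype=torch.float))
    pool.allocate_private_blocks([BlockSpec(numel=256, dtype=torch.float16)])

    block = pool.pop_public_block()       # counters move from free to used
    # ... use block.payload as the backing storage of a chunk ...
    pool.free_public_block(block)         # block returns to the free list

    private = pool.get_private_block(256, torch.float16)  # pops the only matching private block
    print(pool)                           # Memory Pool(public_space = ..., private_space = ...)

Note that each allocate method can be called only once per pool, so all public and private blocks must be requested up front.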