mirror of https://github.com/hpcaitech/ColossalAI
[refactor] move chunk and chunkmgr to directory gemini (#1182)
parent
6b2f2ab9bb
commit
372f791444
|
@ -1,5 +1,10 @@
|
||||||
|
from .chunk import TensorInfo, Chunk, TensorState
|
||||||
|
from .chunk_mgr import ChunkManager
|
||||||
from .stateful_tensor_mgr import StatefulTensorMgr
|
from .stateful_tensor_mgr import StatefulTensorMgr
|
||||||
from .tensor_placement_policy import TensorPlacementPolicyFactory
|
from .tensor_placement_policy import TensorPlacementPolicyFactory
|
||||||
from .gemini_mgr import GeminiManager
|
from .gemini_mgr import GeminiManager
|
||||||
|
|
||||||
__all__ = ['StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager']
|
__all__ = [
|
||||||
|
'StatefulTensorMgr', 'TensorPlacementPolicyFactory', 'GeminiManager', 'ChunkManager', 'TensorInfo', 'Chunk',
|
||||||
|
'TensorState'
|
||||||
|
]
|
||||||
|
|
|
@ -0,0 +1,315 @@
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional, Dict, List
|
||||||
|
|
||||||
|
from colossalai.core import global_context as gpc
|
||||||
|
from colossalai.context import ParallelMode
|
||||||
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
|
|
||||||
|
class TensorState(Enum):
|
||||||
|
FREE = 0
|
||||||
|
COMPUTE = 1
|
||||||
|
HOLD = 2
|
||||||
|
HOLD_AFTER_BWD = 3
|
||||||
|
READY_FOR_REDUCE = 4
|
||||||
|
|
||||||
|
|
||||||
|
STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
|
||||||
|
(TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
|
||||||
|
(TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
|
||||||
|
(TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
|
||||||
|
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
|
||||||
|
TensorState.HOLD))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TensorInfo:
|
||||||
|
state: TensorState
|
||||||
|
offset: int
|
||||||
|
end: int
|
||||||
|
|
||||||
|
|
||||||
|
class ChunkFullError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def is_storage_empty(tensor: torch.Tensor) -> bool:
|
||||||
|
return tensor.storage().size() == 0
|
||||||
|
|
||||||
|
|
||||||
|
def free_storage(tensor: torch.Tensor) -> None:
|
||||||
|
if not is_storage_empty(tensor):
|
||||||
|
tensor.storage().resize_(0)
|
||||||
|
|
||||||
|
|
||||||
|
def alloc_storage(tensor: torch.Tensor) -> None:
|
||||||
|
if is_storage_empty(tensor):
|
||||||
|
tensor.storage().resize_(tensor.numel())
|
||||||
|
|
||||||
|
|
||||||
|
class Chunk:
|
||||||
|
"""
|
||||||
|
A chunk is a contiguous memory space which contains multiple tensors.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chunk_size (int): the number of elements in a chunk
|
||||||
|
src_rank (int): the process which owns the chunk
|
||||||
|
dtype (torch.dtype): the data type of the chunk
|
||||||
|
init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU.
|
||||||
|
force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
chunk_size: int,
|
||||||
|
src_rank: int,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
init_device: Optional[torch.device] = None,
|
||||||
|
force_data_on_cuda: bool = False) -> None:
|
||||||
|
self.size = chunk_size
|
||||||
|
self.utilized_size = 0
|
||||||
|
self.src_rank = src_rank
|
||||||
|
self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
|
||||||
|
self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
|
||||||
|
self.dtype = dtype
|
||||||
|
device = init_device or get_current_device()
|
||||||
|
if force_data_on_cuda:
|
||||||
|
self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device())
|
||||||
|
self._cpu_data = torch.empty(chunk_size, dtype=dtype)
|
||||||
|
if device.type == 'cuda':
|
||||||
|
free_storage(self._cpu_data)
|
||||||
|
else:
|
||||||
|
free_storage(self.data)
|
||||||
|
else:
|
||||||
|
self.data = torch.empty(chunk_size, dtype=dtype, device=device)
|
||||||
|
self._cpu_data = None
|
||||||
|
|
||||||
|
# we only keep the chunk in full in the process by which the tensor is owned
|
||||||
|
if not self.is_src_rank:
|
||||||
|
free_storage(self._payload)
|
||||||
|
|
||||||
|
# each tensor is associated with a TensorInfo to track meta info
|
||||||
|
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
|
||||||
|
self.mem = self.size * self.data.element_size()
|
||||||
|
|
||||||
|
def append(self, tensor: torch.Tensor) -> None:
|
||||||
|
"""
|
||||||
|
Add a tensor to the chunk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (torch.Tensor): a tensor to be added to the chunk
|
||||||
|
"""
|
||||||
|
assert tensor.dtype == self.dtype
|
||||||
|
new_utilized_size = self.utilized_size + tensor.numel()
|
||||||
|
|
||||||
|
# raise exception when the chunk size is exceeded
|
||||||
|
if new_utilized_size > self.size:
|
||||||
|
raise ChunkFullError
|
||||||
|
|
||||||
|
# set tensor state
|
||||||
|
tensor_state = TensorState.FREE
|
||||||
|
|
||||||
|
# if the process owns the rank, then copy the tensor to its chunk buffer
|
||||||
|
# otherwise set its storage size to 0 to reduce memory consumption
|
||||||
|
if self.is_src_rank:
|
||||||
|
self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
|
||||||
|
tensor_state = TensorState.HOLD
|
||||||
|
assert type(self._payload) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
|
||||||
|
tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
|
||||||
|
else:
|
||||||
|
tensor.storage().resize_(0)
|
||||||
|
self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
|
||||||
|
self.utilized_size = new_utilized_size
|
||||||
|
|
||||||
|
def release(self) -> None:
|
||||||
|
"""
|
||||||
|
Release the memory space on processes which do not own the chunk.
|
||||||
|
"""
|
||||||
|
if not self.is_src_rank:
|
||||||
|
free_storage(self._payload)
|
||||||
|
self._update_tensors_state(TensorState.FREE)
|
||||||
|
|
||||||
|
def _update_tensors_ptr(self) -> None:
|
||||||
|
assert type(self._payload) == torch.Tensor
|
||||||
|
for tensor, tensor_info in self.tensors_info.items():
|
||||||
|
tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
||||||
|
|
||||||
|
def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
|
||||||
|
for tensor_info in self.tensors_info.values():
|
||||||
|
if prev_state is None or tensor_info.state == prev_state:
|
||||||
|
tensor_info.state = next_state
|
||||||
|
|
||||||
|
def access(self) -> None:
|
||||||
|
"""
|
||||||
|
Broadcast the chunk to synchronize the tensors across data parallel processes.
|
||||||
|
"""
|
||||||
|
# recover the chunk on non-owner processes
|
||||||
|
# and broadcast the chunk from the source to all processes
|
||||||
|
if not self.is_src_rank:
|
||||||
|
alloc_storage(self._payload)
|
||||||
|
self.move_device(get_current_device(), update_ptr=False)
|
||||||
|
dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
|
||||||
|
|
||||||
|
# update tensor meta info
|
||||||
|
self._update_tensors_ptr()
|
||||||
|
if not self.is_src_rank:
|
||||||
|
self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)
|
||||||
|
|
||||||
|
def move_device(self, device: torch.device, update_ptr: bool = True) -> None:
|
||||||
|
"""
|
||||||
|
Move the chunk to a target device.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
device (torch.device): the target device for data movement.
|
||||||
|
"""
|
||||||
|
if self._payload.device == device:
|
||||||
|
return
|
||||||
|
if self._cpu_data is None:
|
||||||
|
self.data.data = self.data.to(device)
|
||||||
|
else:
|
||||||
|
if device.type == 'cuda':
|
||||||
|
# cpu -> cuda
|
||||||
|
src = self._cpu_data
|
||||||
|
dest = self.data
|
||||||
|
else:
|
||||||
|
# cuda -> cpu
|
||||||
|
src = self.data
|
||||||
|
dest = self._cpu_data
|
||||||
|
alloc_storage(dest)
|
||||||
|
dest.copy_(src)
|
||||||
|
free_storage(src)
|
||||||
|
|
||||||
|
if update_ptr:
|
||||||
|
self._update_tensors_ptr()
|
||||||
|
|
||||||
|
def reduce(self, is_all_reduce: bool = False) -> None:
|
||||||
|
"""
|
||||||
|
Reduce or all-reduce the chunk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false.
|
||||||
|
"""
|
||||||
|
self.move_device(get_current_device(), update_ptr=False)
|
||||||
|
if is_all_reduce:
|
||||||
|
dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
|
||||||
|
else:
|
||||||
|
dist.reduce(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
|
||||||
|
self._update_tensors_ptr()
|
||||||
|
self._update_tensors_state(TensorState.HOLD)
|
||||||
|
|
||||||
|
def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
|
||||||
|
"""
|
||||||
|
Make a transition of the tensor into the next state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (torch.Tensor): a torch Tensor object.
|
||||||
|
tensor_state (TensorState): the target state for transition.
|
||||||
|
"""
|
||||||
|
assert tensor != TensorState.FREE, 'Can only set a chunk of tensors to FREE'
|
||||||
|
# As the gradient hook can be triggered either before or after post-backward
|
||||||
|
# tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
|
||||||
|
# or compute -> ready_for_reduce -> hold_after_bwd
|
||||||
|
# the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd
|
||||||
|
# this function only apply valid state transformation
|
||||||
|
# invalid calls will be ignored and nothing changes
|
||||||
|
if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
|
||||||
|
# print(
|
||||||
|
# f'WARNING: Rank{gpc.get_global_rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
|
||||||
|
# )
|
||||||
|
return
|
||||||
|
self.tensors_info[tensor].state = tensor_state
|
||||||
|
|
||||||
|
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
|
||||||
|
"""
|
||||||
|
Copy data slice to the memory space indexed by the input tensor in the chunk.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tensor (torch.Tensor): the tensor used to retrive meta information
|
||||||
|
data_slice (torch.Tensor): the tensor to be copied to the chunk
|
||||||
|
"""
|
||||||
|
tensor_info = self.tensors_info[tensor]
|
||||||
|
self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
|
||||||
|
tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def can_release(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check whether the chunk can be released.
|
||||||
|
"""
|
||||||
|
for tensor_info in self.tensors_info.values():
|
||||||
|
if tensor_info.state != TensorState.HOLD:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def can_move_device(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check whether the chunk can be moved across devices.
|
||||||
|
"""
|
||||||
|
for tensor_info in self.tensors_info.values():
|
||||||
|
if tensor_info.state in (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def can_reduce(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check whether the chunk can be reduced.
|
||||||
|
"""
|
||||||
|
for tensor_info in self.tensors_info.values():
|
||||||
|
if tensor_info.state != TensorState.READY_FOR_REDUCE:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_empty(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check whether the chunk is empty.
|
||||||
|
"""
|
||||||
|
return is_storage_empty(self._payload)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}'
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_inf_or_nan(self) -> bool:
|
||||||
|
"""
|
||||||
|
Check if the chunk has inf or nan values.
|
||||||
|
"""
|
||||||
|
return torch.isinf(self._payload[:self.utilized_size]).any().item() or \
|
||||||
|
torch.isnan(self._payload[:self.utilized_size]).any().item()
|
||||||
|
|
||||||
|
def copy_(self, dest_chunk: 'Chunk'):
|
||||||
|
"""
|
||||||
|
Copy the data of this chunk to a destination chunk.
|
||||||
|
"""
|
||||||
|
assert not self.is_empty
|
||||||
|
assert not dest_chunk.is_empty
|
||||||
|
assert self.size == dest_chunk.size
|
||||||
|
assert self.utilized_size == dest_chunk.utilized_size
|
||||||
|
self._payload.copy_(dest_chunk._payload)
|
||||||
|
self._update_tensors_ptr()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def device_type(self) -> str:
|
||||||
|
"""
|
||||||
|
Get the device type of the chunk.
|
||||||
|
"""
|
||||||
|
return self._payload.device.type
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return hash(id(self))
|
||||||
|
|
||||||
|
def __eq__(self, __o: object) -> bool:
|
||||||
|
return self is __o
|
||||||
|
|
||||||
|
def get_tensors(self) -> List[torch.Tensor]:
|
||||||
|
return list(self.tensors_info.keys())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _payload(self) -> torch.Tensor:
|
||||||
|
if self._cpu_data is None or is_storage_empty(self._cpu_data):
|
||||||
|
return self.data
|
||||||
|
return self._cpu_data
|
|
@ -1,318 +1,11 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
|
from typing import Optional, Dict, Deque, Set, List, Tuple, Iterable
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from colossalai.core import global_context as gpc
|
|
||||||
from colossalai.context import ParallelMode
|
from colossalai.context import ParallelMode
|
||||||
|
from colossalai.core import global_context as gpc
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
from .chunk import Chunk, ChunkFullError, TensorState
|
||||||
|
|
||||||
class TensorState(Enum):
|
|
||||||
FREE = 0
|
|
||||||
COMPUTE = 1
|
|
||||||
HOLD = 2
|
|
||||||
HOLD_AFTER_BWD = 3
|
|
||||||
READY_FOR_REDUCE = 4
|
|
||||||
|
|
||||||
|
|
||||||
STATE_TRANS = ((TensorState.FREE, TensorState.HOLD), (TensorState.FREE, TensorState.COMPUTE),
|
|
||||||
(TensorState.HOLD, TensorState.FREE), (TensorState.HOLD, TensorState.COMPUTE),
|
|
||||||
(TensorState.COMPUTE, TensorState.HOLD), (TensorState.COMPUTE, TensorState.HOLD_AFTER_BWD),
|
|
||||||
(TensorState.COMPUTE, TensorState.READY_FOR_REDUCE), (TensorState.HOLD_AFTER_BWD, TensorState.COMPUTE),
|
|
||||||
(TensorState.HOLD_AFTER_BWD, TensorState.READY_FOR_REDUCE), (TensorState.READY_FOR_REDUCE,
|
|
||||||
TensorState.HOLD))
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TensorInfo:
|
|
||||||
state: TensorState
|
|
||||||
offset: int
|
|
||||||
end: int
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkFullError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def is_storage_empty(tensor: torch.Tensor) -> bool:
|
|
||||||
return tensor.storage().size() == 0
|
|
||||||
|
|
||||||
|
|
||||||
def free_storage(tensor: torch.Tensor) -> None:
|
|
||||||
if not is_storage_empty(tensor):
|
|
||||||
tensor.storage().resize_(0)
|
|
||||||
|
|
||||||
|
|
||||||
def alloc_storage(tensor: torch.Tensor) -> None:
|
|
||||||
if is_storage_empty(tensor):
|
|
||||||
tensor.storage().resize_(tensor.numel())
|
|
||||||
|
|
||||||
|
|
||||||
class Chunk:
|
|
||||||
"""
|
|
||||||
A chunk is a contiguous memory space which contains multiple tensors.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chunk_size (int): the number of elements in a chunk
|
|
||||||
src_rank (int): the process which owns the chunk
|
|
||||||
dtype (torch.dtype): the data type of the chunk
|
|
||||||
init_device (torch.device): optional, the device where the tensor is initialized. The default value is None, which is the current GPU.
|
|
||||||
force_data_on_cuda (bool): optional, if True, chunk.data is always on cuda. Defaults to False.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
chunk_size: int,
|
|
||||||
src_rank: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
init_device: Optional[torch.device] = None,
|
|
||||||
force_data_on_cuda: bool = False) -> None:
|
|
||||||
self.size = chunk_size
|
|
||||||
self.utilized_size = 0
|
|
||||||
self.src_rank = src_rank
|
|
||||||
self.is_src_rank = gpc.get_local_rank(ParallelMode.DATA) == src_rank
|
|
||||||
self.global_src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[src_rank]
|
|
||||||
self.dtype = dtype
|
|
||||||
device = init_device or get_current_device()
|
|
||||||
if force_data_on_cuda:
|
|
||||||
self.data = torch.empty(chunk_size, dtype=dtype, device=get_current_device())
|
|
||||||
self._cpu_data = torch.empty(chunk_size, dtype=dtype)
|
|
||||||
if device.type == 'cuda':
|
|
||||||
free_storage(self._cpu_data)
|
|
||||||
else:
|
|
||||||
free_storage(self.data)
|
|
||||||
else:
|
|
||||||
self.data = torch.empty(chunk_size, dtype=dtype, device=device)
|
|
||||||
self._cpu_data = None
|
|
||||||
|
|
||||||
# we only keep the chunk in full in the process by which the tensor is owned
|
|
||||||
if not self.is_src_rank:
|
|
||||||
free_storage(self._payload)
|
|
||||||
|
|
||||||
# each tensor is associated with a TensorInfo to track meta info
|
|
||||||
self.tensors_info: Dict[torch.Tensor, TensorInfo] = {}
|
|
||||||
self.mem = self.size * self.data.element_size()
|
|
||||||
|
|
||||||
def append(self, tensor: torch.Tensor) -> None:
|
|
||||||
"""
|
|
||||||
Add a tensor to the chunk.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tensor (torch.Tensor): a tensor to be added to the chunk
|
|
||||||
"""
|
|
||||||
assert tensor.dtype == self.dtype
|
|
||||||
new_utilized_size = self.utilized_size + tensor.numel()
|
|
||||||
|
|
||||||
# raise exception when the chunk size is exceeded
|
|
||||||
if new_utilized_size > self.size:
|
|
||||||
raise ChunkFullError
|
|
||||||
|
|
||||||
# set tensor state
|
|
||||||
tensor_state = TensorState.FREE
|
|
||||||
|
|
||||||
# if the process owns the rank, then copy the tensor to its chunk buffer
|
|
||||||
# otherwise set its storage size to 0 to reduce memory consumption
|
|
||||||
if self.is_src_rank:
|
|
||||||
self._payload[self.utilized_size:new_utilized_size].copy_(tensor.flatten())
|
|
||||||
tensor_state = TensorState.HOLD
|
|
||||||
assert type(self._payload) == torch.Tensor, "copy_tensor_to_chunk_slice must use a torch tensor"
|
|
||||||
tensor.data = self._payload[self.utilized_size:new_utilized_size].view(tensor.shape)
|
|
||||||
else:
|
|
||||||
tensor.storage().resize_(0)
|
|
||||||
self.tensors_info[tensor] = TensorInfo(tensor_state, self.utilized_size, new_utilized_size)
|
|
||||||
self.utilized_size = new_utilized_size
|
|
||||||
|
|
||||||
def release(self) -> None:
|
|
||||||
"""
|
|
||||||
Release the memory space on processes which do not own the chunk.
|
|
||||||
"""
|
|
||||||
if not self.is_src_rank:
|
|
||||||
free_storage(self._payload)
|
|
||||||
self._update_tensors_state(TensorState.FREE)
|
|
||||||
|
|
||||||
def _update_tensors_ptr(self) -> None:
|
|
||||||
assert type(self._payload) == torch.Tensor
|
|
||||||
for tensor, tensor_info in self.tensors_info.items():
|
|
||||||
tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
|
||||||
|
|
||||||
def _update_tensors_state(self, next_state: TensorState, prev_state: Optional[TensorState] = None):
|
|
||||||
for tensor_info in self.tensors_info.values():
|
|
||||||
if prev_state is None or tensor_info.state == prev_state:
|
|
||||||
tensor_info.state = next_state
|
|
||||||
|
|
||||||
def access(self) -> None:
|
|
||||||
"""
|
|
||||||
Broadcast the chunk to synchronize the tensors across data parallel processes.
|
|
||||||
"""
|
|
||||||
# recover the chunk on non-owner processes
|
|
||||||
# and broadcast the chunk from the source to all processes
|
|
||||||
if not self.is_src_rank:
|
|
||||||
alloc_storage(self._payload)
|
|
||||||
self.move_device(get_current_device(), update_ptr=False)
|
|
||||||
dist.broadcast(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
|
|
||||||
|
|
||||||
# update tensor meta info
|
|
||||||
self._update_tensors_ptr()
|
|
||||||
if not self.is_src_rank:
|
|
||||||
self._update_tensors_state(TensorState.HOLD, prev_state=TensorState.FREE)
|
|
||||||
|
|
||||||
def move_device(self, device: torch.device, update_ptr: bool = True) -> None:
|
|
||||||
"""
|
|
||||||
Move the chunk to a target device.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
device (torch.device): the target device for data movement.
|
|
||||||
"""
|
|
||||||
if self._payload.device == device:
|
|
||||||
return
|
|
||||||
if self._cpu_data is None:
|
|
||||||
self.data.data = self.data.to(device)
|
|
||||||
else:
|
|
||||||
if device.type == 'cuda':
|
|
||||||
# cpu -> cuda
|
|
||||||
src = self._cpu_data
|
|
||||||
dest = self.data
|
|
||||||
else:
|
|
||||||
# cuda -> cpu
|
|
||||||
src = self.data
|
|
||||||
dest = self._cpu_data
|
|
||||||
alloc_storage(dest)
|
|
||||||
dest.copy_(src)
|
|
||||||
free_storage(src)
|
|
||||||
|
|
||||||
if update_ptr:
|
|
||||||
self._update_tensors_ptr()
|
|
||||||
|
|
||||||
def reduce(self, is_all_reduce: bool = False) -> None:
|
|
||||||
"""
|
|
||||||
Reduce or all-reduce the chunk.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
is_all_reduce (bool): optional, whether to all-reduce the chunk. The default is false.
|
|
||||||
"""
|
|
||||||
self.move_device(get_current_device(), update_ptr=False)
|
|
||||||
if is_all_reduce:
|
|
||||||
dist.all_reduce(self.data, group=gpc.get_group(ParallelMode.DATA))
|
|
||||||
else:
|
|
||||||
dist.reduce(self.data, self.global_src_rank, group=gpc.get_group(ParallelMode.DATA))
|
|
||||||
self._update_tensors_ptr()
|
|
||||||
self._update_tensors_state(TensorState.HOLD)
|
|
||||||
|
|
||||||
def tensor_trans_state(self, tensor: torch.Tensor, tensor_state: TensorState) -> None:
|
|
||||||
"""
|
|
||||||
Make a transition of the tensor into the next state.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tensor (torch.Tensor): a torch Tensor object.
|
|
||||||
tensor_state (TensorState): the target state for transition.
|
|
||||||
"""
|
|
||||||
assert tensor != TensorState.FREE, 'Can only set a chunk of tensors to FREE'
|
|
||||||
# As the gradient hook can be triggered either before or after post-backward
|
|
||||||
# tensor's state can be compute -> hold_after_bwd -> ready_for_reduce
|
|
||||||
# or compute -> ready_for_reduce -> hold_after_bwd
|
|
||||||
# the second one is invalid, we just ignore ready_for_reduce -> hold_after_bwd
|
|
||||||
# this function only apply valid state transformation
|
|
||||||
# invalid calls will be ignored and nothing changes
|
|
||||||
if (self.tensors_info[tensor].state, tensor_state) not in STATE_TRANS:
|
|
||||||
# print(
|
|
||||||
# f'WARNING: Rank{gpc.get_global_rank()} apply invalid state trans: {self.tensors_info[tensor].state} to {tensor_state}'
|
|
||||||
# )
|
|
||||||
return
|
|
||||||
self.tensors_info[tensor].state = tensor_state
|
|
||||||
|
|
||||||
def copy_tensor_to_chunk_slice(self, tensor: torch.Tensor, data_slice: torch.Tensor) -> None:
|
|
||||||
"""
|
|
||||||
Copy data slice to the memory space indexed by the input tensor in the chunk.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tensor (torch.Tensor): the tensor used to retrive meta information
|
|
||||||
data_slice (torch.Tensor): the tensor to be copied to the chunk
|
|
||||||
"""
|
|
||||||
tensor_info = self.tensors_info[tensor]
|
|
||||||
self._payload[tensor_info.offset:tensor_info.end].copy_(data_slice.flatten())
|
|
||||||
tensor.data = self._payload[tensor_info.offset:tensor_info.end].view(tensor.shape)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def can_release(self) -> bool:
|
|
||||||
"""
|
|
||||||
Check whether the chunk can be released.
|
|
||||||
"""
|
|
||||||
for tensor_info in self.tensors_info.values():
|
|
||||||
if tensor_info.state != TensorState.HOLD:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def can_move_device(self) -> bool:
|
|
||||||
"""
|
|
||||||
Check whether the chunk can be moved across devices.
|
|
||||||
"""
|
|
||||||
for tensor_info in self.tensors_info.values():
|
|
||||||
if tensor_info.state in (TensorState.COMPUTE, TensorState.READY_FOR_REDUCE):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def can_reduce(self) -> bool:
|
|
||||||
"""
|
|
||||||
Check whether the chunk can be reduced.
|
|
||||||
"""
|
|
||||||
for tensor_info in self.tensors_info.values():
|
|
||||||
if tensor_info.state != TensorState.READY_FOR_REDUCE:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_empty(self) -> bool:
|
|
||||||
"""
|
|
||||||
Check whether the chunk is empty.
|
|
||||||
"""
|
|
||||||
return is_storage_empty(self._payload)
|
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
|
||||||
return f'Chunk: src rank={self.src_rank} ,size={self.size}, utilization={self.utilized_size/self.size*100:.2f}%, freed={self.is_empty}, tensor states={[info.state.name for info in self.tensors_info.values()]}'
|
|
||||||
|
|
||||||
@property
|
|
||||||
def has_inf_or_nan(self) -> bool:
|
|
||||||
"""
|
|
||||||
Check if the chunk has inf or nan values.
|
|
||||||
"""
|
|
||||||
return torch.isinf(self._payload[:self.utilized_size]).any().item() or \
|
|
||||||
torch.isnan(self._payload[:self.utilized_size]).any().item()
|
|
||||||
|
|
||||||
def copy_(self, dest_chunk: 'Chunk'):
|
|
||||||
"""
|
|
||||||
Copy the data of this chunk to a destination chunk.
|
|
||||||
"""
|
|
||||||
assert not self.is_empty
|
|
||||||
assert not dest_chunk.is_empty
|
|
||||||
assert self.size == dest_chunk.size
|
|
||||||
assert self.utilized_size == dest_chunk.utilized_size
|
|
||||||
self._payload.copy_(dest_chunk._payload)
|
|
||||||
self._update_tensors_ptr()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def device_type(self) -> str:
|
|
||||||
"""
|
|
||||||
Get the device type of the chunk.
|
|
||||||
"""
|
|
||||||
return self._payload.device.type
|
|
||||||
|
|
||||||
def __hash__(self) -> int:
|
|
||||||
return hash(id(self))
|
|
||||||
|
|
||||||
def __eq__(self, __o: object) -> bool:
|
|
||||||
return self is __o
|
|
||||||
|
|
||||||
def get_tensors(self) -> List[torch.Tensor]:
|
|
||||||
return list(self.tensors_info.keys())
|
|
||||||
|
|
||||||
@property
|
|
||||||
def _payload(self) -> torch.Tensor:
|
|
||||||
if self._cpu_data is None or is_storage_empty(self._cpu_data):
|
|
||||||
return self.data
|
|
||||||
return self._cpu_data
|
|
||||||
|
|
||||||
|
|
||||||
class ChunkManager:
|
class ChunkManager:
|
|
@ -3,8 +3,8 @@ import functools
|
||||||
from .memory_tracer.memstats_collector import MemStatsCollectorV2
|
from .memory_tracer.memstats_collector import MemStatsCollectorV2
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
from time import time
|
from time import time
|
||||||
from colossalai.tensor.chunk import Chunk, ChunkManager
|
from colossalai.gemini import Chunk, ChunkManager
|
||||||
from .placement_policy import PlacementPolicy, PlacementPolicyFactory
|
from .placement_policy import PlacementPolicyFactory
|
||||||
|
|
||||||
|
|
||||||
class GeminiManager:
|
class GeminiManager:
|
||||||
|
|
|
@ -2,7 +2,7 @@ from colossalai.gemini.memory_tracer import SyncCudaMemoryMonitor
|
||||||
from colossalai.utils.memory import colo_device_memory_used, colo_device_memory_capacity
|
from colossalai.utils.memory import colo_device_memory_used, colo_device_memory_capacity
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from colossalai.gemini.stateful_tensor import StatefulTensor
|
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||||
from colossalai.tensor import ChunkManager
|
from colossalai.gemini import ChunkManager
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import time
|
import time
|
||||||
|
|
|
@ -8,7 +8,7 @@ from colossalai.utils.memory import colo_device_memory_capacity
|
||||||
from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollectorV2
|
from colossalai.gemini.memory_tracer.memstats_collector import MemStatsCollectorV2
|
||||||
from typing import Type
|
from typing import Type
|
||||||
import functools
|
import functools
|
||||||
from colossalai.tensor.chunk import Chunk, ChunkManager
|
from colossalai.gemini import Chunk, ChunkManager
|
||||||
|
|
||||||
|
|
||||||
class PlacementPolicy(ABC):
|
class PlacementPolicy(ABC):
|
||||||
|
|
|
@ -5,7 +5,7 @@ from colossalai.core import global_context as gpc
|
||||||
from colossalai.context import ParallelMode
|
from colossalai.context import ParallelMode
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
|
from colossalai.zero.utils.zero_hook_v2 import ZeROHookV2
|
||||||
from colossalai.tensor.chunk import TensorState, Chunk
|
from colossalai.gemini.chunk import TensorState, Chunk
|
||||||
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
from colossalai.tensor.param_op_hook import ParamOpHookManager
|
||||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||||
from typing import Dict, Iterable, List, Optional
|
from typing import Dict, Iterable, List, Optional
|
||||||
|
|
|
@ -5,7 +5,6 @@ from .colo_parameter import ColoParameter
|
||||||
from .utils import convert_parameter, named_params_with_colotensor
|
from .utils import convert_parameter, named_params_with_colotensor
|
||||||
from .dist_spec_mgr import DistSpecManager
|
from .dist_spec_mgr import DistSpecManager
|
||||||
from .param_op_hook import ParamOpHook, ParamOpHookManager
|
from .param_op_hook import ParamOpHook, ParamOpHookManager
|
||||||
from .chunk import ChunkManager, TensorState
|
|
||||||
from . import distspec
|
from . import distspec
|
||||||
from .process_group import ProcessGroup
|
from .process_group import ProcessGroup
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import torch
|
import torch
|
||||||
from colossalai.tensor.param_op_hook import ParamOpHook
|
from colossalai.tensor.param_op_hook import ParamOpHook
|
||||||
from colossalai.tensor.chunk import ChunkManager, TensorState
|
from colossalai.gemini import TensorState
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import List
|
from typing import List
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
|
@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.cuda import get_current_device
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
from colossalai.tensor import ChunkManager
|
from colossalai.gemini import ChunkManager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from colossalai.nn.parallel import ColoDDP, ZeroDDP
|
from colossalai.nn.parallel import ColoDDP, ZeroDDP
|
||||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
from colossalai.gemini.gemini_mgr import GeminiManager
|
||||||
|
|
|
@ -6,7 +6,7 @@ from colossalai.testing import rerun_if_address_is_in_use
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.cuda import get_current_device
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
from colossalai.tensor import ChunkManager
|
from colossalai.gemini import ChunkManager
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||||
from colossalai.nn.parallel import ZeroDDP, ColoDDP
|
from colossalai.nn.parallel import ZeroDDP, ColoDDP
|
||||||
|
|
|
@ -5,14 +5,7 @@ import torch.multiprocessing as mp
|
||||||
from colossalai.testing import rerun_if_address_is_in_use
|
from colossalai.testing import rerun_if_address_is_in_use
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.cuda import get_current_device
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
|
||||||
from colossalai.tensor import ChunkManager
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
|
||||||
from colossalai.nn.parallel import ZeroDDP, ColoDDP
|
|
||||||
from colossalai.gemini.gemini_mgr import GeminiManager
|
|
||||||
from typing import Callable
|
|
||||||
from collections import OrderedDict
|
|
||||||
from colossalai.nn.parallel.reducer import Reducer
|
from colossalai.nn.parallel.reducer import Reducer
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from torch.distributed.distributed_c10d import _get_default_group
|
from torch.distributed.distributed_c10d import _get_default_group
|
||||||
|
|
|
@ -4,7 +4,7 @@ import pytest
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
from typing import List
|
from typing import List
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from colossalai.tensor import ChunkManager
|
from colossalai.gemini import ChunkManager
|
||||||
from colossalai.testing import rerun_if_address_is_in_use, parameterize
|
from colossalai.testing import rerun_if_address_is_in_use, parameterize
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
|
|
|
@ -7,7 +7,7 @@ from colossalai.testing import rerun_if_address_is_in_use
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.cuda import get_current_device
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
from colossalai.tensor import ChunkManager
|
from colossalai.gemini import ChunkManager
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from _utils import tensor_equal, set_seed, tensor_shard_equal
|
from _utils import tensor_equal, set_seed, tensor_shard_equal
|
||||||
|
|
|
@ -7,13 +7,12 @@ from colossalai.testing import rerun_if_address_is_in_use
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.cuda import get_current_device
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
from colossalai.tensor import ChunkManager
|
|
||||||
from colossalai.core import global_context as gpc
|
from colossalai.core import global_context as gpc
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from tests.test_tensor._utils import set_seed
|
from tests.test_tensor._utils import set_seed
|
||||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||||
from colossalai.nn.parallel.data_parallel import ZeroDDP
|
from colossalai.nn.parallel.data_parallel import ZeroDDP
|
||||||
from colossalai.gemini import GeminiManager
|
from colossalai.gemini import ChunkManager, GeminiManager
|
||||||
from colossalai.testing import parameterize
|
from colossalai.testing import parameterize
|
||||||
from colossalai.nn.optimizer import HybridAdam
|
from colossalai.nn.optimizer import HybridAdam
|
||||||
from colossalai.zero import ZeroOptimizer
|
from colossalai.zero import ZeroOptimizer
|
||||||
|
|
Loading…
Reference in New Issue