[gemini] add GeminiMemoryManager (#832)

* refactor StatefulTensor, tensor utilities

* add unit test for GeminiMemoryManager
pull/837/head
HELSON 2022-04-24 13:08:48 +08:00 committed by GitHub
parent 35ea6e1023
commit e5ea3fdeef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 414 additions and 180 deletions

View File

@ -0,0 +1,45 @@
from enum import EnumMeta
class GeminiMemoryManager(object):
    """Keep an instance counter and byte totals per device type and per state.

    ``total_mem`` maps ``'cpu'`` / ``'cuda'`` to an integer byte count.
    ``state_mem`` maps ``'cpu'`` / ``'cuda'`` to a dict keyed by every member
    of ``states_cls``, each holding an integer byte count. The manager only
    stores these numbers; callers are responsible for updating them.

    Args:
        states_cls (EnumMeta): the enum class whose members are used as the
            keys of the per-state tables.
    """

    def __init__(self, states_cls: EnumMeta):
        super().__init__()
        self.states_cls = states_cls
        self._cnter = 0    # number of registered instances
        self.total_mem = dict()
        self.state_mem = {'cpu': dict(), 'cuda': dict()}
        # fill every table with zeros
        self.reset()

    @property
    def total_number(self):
        """Number of instances registered since the last reset."""
        return self._cnter

    def reset(self):
        """Zero the instance counter and every memory table, in place."""
        self._cnter = 0
        for device_type in ('cpu', 'cuda'):
            # aggregated bytes on this device type
            self.total_mem[device_type] = 0
            # bytes on this device type, broken down by state
            per_state = self.state_mem[device_type]
            for state in self.states_cls:
                per_state[state] = 0

    def register_new_instance(self):
        """Count one more instance."""
        self._cnter += 1

    def print_info(self):
        """Print the counter, the per-device totals and the per-state breakdown."""
        summary = [
            f"Total number: {self.total_number}",
            f"Total CPU memory occupation: {self.total_mem['cpu']}",
            f"Total CUDA memory occupation: {self.total_mem['cuda']}\n",
        ]
        print('\n'.join(summary))

        for state in self.states_cls:
            state_lines = [
                f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
                f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n",
            ]
            print('\n'.join(state_lines))

View File

@ -0,0 +1,204 @@
from enum import Enum
from typing import Optional
import torch
from typing import Union
from colossalai.gemini.gemini_context import GeminiMemoryManager
def sizeof_tensor(tensor: torch.Tensor):
    """Return the byte size of ``tensor``: bytes per element times element count."""
    bytes_per_element = tensor.element_size()
    return bytes_per_element * tensor.numel()
class TensorState(Enum):
    """State labels carried by a ``StatefulTensor`` and used as memory-table keys."""
    # FREE: the stateful tensor holds no payload (its payload must be None).
    FREE = 0
    # HOLD: default state for a newly wrapped payload.
    HOLD = 1
    # NOTE(review): presumably "kept after the forward pass" — confirm with the hooks that set it.
    HOLD_AFTER_FWD = 2
    # NOTE(review): presumably "kept after the backward pass" — confirm with the hooks that set it.
    HOLD_AFTER_BWD = 3
    # NOTE(review): presumably "currently needed for computation" — confirm with the hooks that set it.
    COMPUTE = 4
class StatefulTensor(object):
    """A torch tensor (the payload) paired with a ``TensorState`` label.

    Every state, payload or device change is mirrored into the class-wide
    manager ``GST_MGR``, which tracks byte totals per device type and state.

    Inspired from the paper:
    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818
    """
    # Global Stateful Tensor Manager, shared by all instances
    GST_MGR = GeminiMemoryManager(TensorState)

    def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
        """Wrap ``maybe_tensor`` under ``state`` and register with the manager.

        A FREE tensor must be created with ``None``; any other state requires
        a real tensor, whose bytes are then added to the manager.
        """
        self._state = state
        self._payload = None
        self._payload_size = 0    # byte size of current payload

        StatefulTensor.GST_MGR.register_new_instance()

        if self._state == TensorState.FREE:
            # a free stateful tensor carries no payload, nothing to account
            assert maybe_tensor is None, f"payload has to None if state is {self._state}"
            return

        # every other state needs a real payload
        assert maybe_tensor is not None, f"payload can't be None if state is {self._state}"
        self._payload = maybe_tensor
        self._payload_size = sizeof_tensor(maybe_tensor)
        # the payload is new to the manager: record it as FREE -> state
        self.__trans_state_update(TensorState.FREE, state)

    def data_ptr(self):
        """Return the payload's data pointer, or 0 when there is no payload."""
        tensor = self._payload
        return 0 if tensor is None else tensor.data_ptr()

    def set_null(self) -> None:
        """Drop the payload and become FREE; a no-op if already FREE."""
        if self.state == TensorState.FREE:
            return
        # account for the removal first, while payload and size are still set
        self.__trans_state_update(self.state, TensorState.FREE)
        self.__release()

    def is_null(self) -> bool:
        """Return True when the state is FREE (also asserts the payload is None)."""
        if self.state != TensorState.FREE:
            return False
        # sanity check: a free tensor never keeps a payload
        assert self.payload is None
        return True

    def trans_state(self, state: TensorState) -> None:
        """Switch to ``state`` and update the manager's per-state bytes.

        A FREE tensor may only be asked to stay FREE. Switching to FREE also
        releases the payload.
        """
        if self.state == TensorState.FREE:
            assert state == TensorState.FREE, "Free stateful tensor can't change to other states"
            return

        self.__trans_state_update(self.state, state)

        if state != TensorState.FREE:
            self._state = state
        else:
            self.__release()

    def move_to(self, device: Union[torch.device, int]):
        """Move the payload to ``device`` and update the manager.

        An int is interpreted as a CUDA device index. Only device types are
        compared, so nothing happens when the target type equals the current one.
        """
        assert self.state is not TensorState.FREE, "Can't move free stateful tensor"

        to_device = device if isinstance(device, torch.device) else torch.device('cuda', device)

        src_type = self.device.type
        dst_type = to_device.type
        if src_type == dst_type:
            return

        # book-keeping first, then the actual data movement
        self.__trans_device_update(src_type, dst_type)
        self.payload.data = self.payload.data.to(to_device)

    def payload_copy(self, tensor) -> None:
        """Copy ``tensor`` element-wise into the current payload (both viewed as 1-D)."""
        flat_dst = self._payload.view(-1)
        flat_dst.copy_(tensor.view(-1))

    def payload_reset(self, tensor) -> None:
        """Replace the payload with ``tensor`` and re-account its memory.

        With an existing payload the state is kept; without one the state
        becomes HOLD.
        """
        assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead"

        if self.payload is None:
            # no previous payload: the new one starts in HOLD
            self._state = TensorState.HOLD
        else:
            # remove the old payload's bytes from the manager
            self.__trans_state_update(self.state, TensorState.FREE)
        del self._payload

        self._payload = tensor
        self._payload_size = sizeof_tensor(tensor)
        # add the new payload's bytes under the current state
        self.__trans_state_update(TensorState.FREE, self.state)

    def payload_relay(self, rhs):
        """Take over the payload of ``rhs``, leaving ``rhs`` FREE.

        ``rhs`` must not be null and this tensor must currently account for
        zero bytes, so only the state of the relayed bytes has to be updated.
        """
        # null relay is not supported
        assert not rhs.is_null()
        # only zero-byte receivers are supported, they need no removal in the manager
        assert self.payload_size == 0

        self._payload, self._payload_size = rhs.payload, rhs.payload_size
        self._state = TensorState.HOLD
        # the bytes leave rhs's state and are now counted as HOLD
        self.__trans_state_update(rhs.state, TensorState.HOLD)

        rhs.__release()

    @property
    def payload(self) -> Optional[torch.Tensor]:
        """The wrapped tensor, or None."""
        return self._payload

    @property
    def payload_size(self) -> int:
        """Byte size recorded for the current payload (0 when released)."""
        return self._payload_size

    @property
    def state(self) -> TensorState:
        """The current state label."""
        return self._state

    @property
    def device(self) -> torch.device:
        """Device of the payload."""
        return self._payload.device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the payload."""
        return self._payload.dtype

    @property
    def shape(self):
        """Shape of the payload."""
        return self._payload.shape

    def to(self, device: torch.device):
        """Always raises: use ``move_to`` so the manager stays in sync."""
        raise RuntimeError("Use move_to(...) instead of call .to() on StatefulTensor")

    def to_(self, device: torch.device):
        """Always raises: use ``move_to`` so the manager stays in sync."""
        raise RuntimeError("Use move_to(...) instead of call .to_() on StatefulTensor")

    def __release(self):
        # forget the payload and mark this tensor FREE; internal use only,
        # the manager is NOT updated here
        self._payload = None
        self._payload_size = 0
        self._state = TensorState.FREE

    def __trans_state_update(self, from_state: TensorState, to_state: TensorState):
        """Update global manager when changing the state of a tensor
        """
        mgr = StatefulTensor.GST_MGR
        nbytes = self.payload_size
        dev_type = self.device.type

        if from_state == TensorState.FREE:
            # coming from FREE means the payload is new to the manager
            mgr.total_mem[dev_type] += nbytes
        else:
            mgr.state_mem[dev_type][from_state] -= nbytes

        if to_state == TensorState.FREE:
            # going to FREE means the payload is about to be dropped
            mgr.total_mem[dev_type] -= nbytes
        else:
            mgr.state_mem[dev_type][to_state] += nbytes

    def __trans_device_update(self, from_type: str, to_type: str):
        """Update global manager when changing the device of a tensor
        """
        mgr = StatefulTensor.GST_MGR
        nbytes, current_state = self.payload_size, self.state

        # aggregated totals
        mgr.total_mem[from_type] -= nbytes
        mgr.total_mem[to_type] += nbytes

        # per-state breakdown
        mgr.state_mem[from_type][current_state] -= nbytes
        mgr.state_mem[to_type][current_state] += nbytes

View File

@ -2,9 +2,8 @@ import functools
import torch
import types
from colossalai.utils.cuda import get_current_device
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
from typing import List
from colossalai.logging import get_dist_logger
@ -30,7 +29,8 @@ class StatefulTensorMgr(object):
self._cpu_gpu_move_volume = 0
def register_stateful_param(self, param: ShardedParamV2) -> None:
def register_stateful_param(self, param) -> None:
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
assert isinstance(param, ShardedParamV2)
for t in param.get_payload_tensors():
assert isinstance(t, StatefulTensor)

View File

@ -4,8 +4,8 @@ import torch
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.gemini.memory_tracer import MemStatsCollector
from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
from typing import Type

View File

@ -1,10 +1,10 @@
import torch
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
from colossalai.gemini.stateful_tensor import StatefulTensor
from typing import Union, Tuple
def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
if issubclass(type(tensor), StatefulTensor):
if isinstance(tensor, StatefulTensor):
t = tensor.payload
elif isinstance(tensor, torch.Tensor):
t = tensor
@ -24,23 +24,24 @@ def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[
def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor,
torch.Tensor]) -> None:
"""
A colossal API for model data tensor move.
"""
A colossal API for model data tensor move.
The src and target tensors could be resident on both CPU and GPU.
NOTE() The source tensor payload will be removed after this function.
The function will record the communication volume between CPU and GPU.
Args:
t_src (Union[StatefulTensor, torch.Tensor]): source tensor
src_t (Union[StatefulTensor, torch.Tensor]): source tensor
tgt_t (Union[StatefulTensor, torch.Tensor]): target tensor
"""
if issubclass(type(src_t), StatefulTensor):
if isinstance(src_t, StatefulTensor):
src_t_payload = src_t.payload
else:
src_t_payload = src_t.data
src_dev = src_t_payload.device
if issubclass(type(tgt_t), StatefulTensor):
if isinstance(tgt_t, StatefulTensor):
tgt_t_payload = tgt_t.payload
else:
tgt_t_payload = tgt_t.data
@ -48,70 +49,56 @@ def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_
tgt_t_payload.copy_(src_t_payload)
# remove payload of src_t
if issubclass(type(src_t), StatefulTensor):
src_t.reset_payload(torch.tensor([], device=src_dev, dtype=src_t_payload.dtype))
if isinstance(src_t, StatefulTensor):
src_t.set_null()
else:
src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)
src_t.data = torch.empty(0, device=src_dev, dtype=src_t_payload.dtype)
def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device,
int]) -> None:
"""
"""
move a tensor to the target_device
Args:
t (Union[StatefulTensor, torch.Tensor]): the tensor be moved
target_device: a traget device, if type is int, it the index of cuda card.
"""
if isinstance(t, torch.Tensor):
t_payload = t
elif issubclass(type(t), StatefulTensor):
t_payload = t.payload
else:
raise TypeError('colo_model_data_move_to_cpu dose not accept type {type(t)}')
if not isinstance(target_device, torch.device):
target_device = torch.device(f'cuda:{target_device}')
# deal with torch.device('cpu') and torch.device('cpu:0)
if t_payload.device.type == target_device.type:
return
t_payload.data = t_payload.data.to(target_device)
if isinstance(t, torch.Tensor):
t.data = t.data.to(target_device)
elif isinstance(t, StatefulTensor):
t.move_to(target_device)
else:
raise TypeError(f'colo_model_data_tensor_move_inline dose not accept type {type(t)}')
def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None:
"""colo_model_data_move_to_cpu
"""colo_model_data_move_to_cpu
move a model data tensor from gpu to cpu
Args:
t (Union[StatefulTensor, torch.Tensor]): _description_
"""
if issubclass(type(t), StatefulTensor):
t_payload = t.payload
elif isinstance(t, torch.Tensor):
t_payload = t
else:
raise TypeError('colo_model_data_move_to_cpu dose not accept type {type(t)}')
if t_payload.device.type == 'cpu':
return
# TODO() optimize the tensor moving with non-blocking
t_payload.data = t_payload.data.cpu()
if isinstance(t, torch.Tensor):
t.data = t.data.cpu()
elif isinstance(t, StatefulTensor):
t.move_to(torch.device('cpu'))
else:
raise TypeError(f'colo_model_data_move_to_cpu dose not accept type {type(t)}')
def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
"""
Clone a model data tensor
Args:
t (Union[StatefulTensor, torch.Tensor]): a model data tensor
target_device (torch.device): the target device
Returns:
torch.Tensor: a cloned torch tensor
"""
t_payload = t.payload if issubclass(type(t), StatefulTensor) else t
ret = t_payload.to(target_device)
return ret
# TODO() rename this function
colo_model_data_tensor_move_inline(t, target_device)
t_payload = t.payload if isinstance(t, StatefulTensor) else t
return t_payload

View File

@ -8,7 +8,7 @@ from .experts import FFNExperts, TPExperts
class ForceFP32Parameter(torch.nn.Parameter):
def half(self, memory_format=None):
return self.data
return self.data.clone()
class NormalNoiseGenerator:

View File

@ -35,4 +35,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model
return zero_model, zero_optimizer
__all__ = ['convert_to_zerov2', 'ShardedModelV2', 'ShardedOptimizerV2']
__all__ = ['convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2']

View File

@ -184,11 +184,12 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if param.grad is not None:
param.grad = param.grad.to(target_device)
param.colo_attr = ShardedParamV2(param, set_data_none=False)
param.colo_attr = ShardedParamV2(param, set_data_none=True)
if self.shard_param:
self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
param.data = param.colo_attr.data_payload # set param.data to payload
param.data = param.colo_attr.data_payload # set param.data to payload
# mark whether the param is replicated
param.colo_attr.is_replicated = self.is_replicated

View File

@ -31,9 +31,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
for i in range(world_size):
if i == rank:
buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
# Release payload here, to decrease peak memory usage
for t in tensor_list:
t.reset_payload(None)
else:
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
@ -44,6 +41,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
for i, t in enumerate(tensor_list):
gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list]
gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape)
t.reset_payload(gathered_payload)
t.payload_reset(gathered_payload)
t.is_sharded = False
offset += tensor_numels[i]

View File

@ -3,10 +3,10 @@ from typing import List, Optional
import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.shard_utils.commons import get_shard
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline
class TensorShardStrategy(BaseShardStrategy):
@ -36,7 +36,7 @@ class TensorShardStrategy(BaseShardStrategy):
assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
f" but current cuda device is {get_current_device()}"
sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
t.reset_payload(sharded_payload)
t.payload_reset(sharded_payload)
t.is_sharded = True
def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
@ -53,6 +53,6 @@ class TensorShardStrategy(BaseShardStrategy):
dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape)
t.reset_payload(gathered_payload)
t.payload_reset(gathered_payload)
colo_model_data_tensor_move_inline(t, target_device)
t.is_sharded = False

View File

@ -3,7 +3,7 @@ from typing import Any, Callable, List, Tuple
import torch
import torch.nn.functional as F
from typing import Union
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
from colossalai.gemini.stateful_tensor import StatefulTensor
def get_gradient_predivide_factor(world_size: int) -> float:

View File

@ -17,11 +17,11 @@ from colossalai.gemini.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_move_to_cpu
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_param.tensorful_state import TensorState
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu
from colossalai.gemini.stateful_tensor import TensorState
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
@ -358,8 +358,11 @@ class ShardedModelV2(nn.Module):
assert param.colo_attr.saved_grad.is_null(
), 'Gradien accumulation is not supported when reuse_fp16_shard=True'
param.colo_attr.reset_grad_payload(grad.data)
param.colo_attr.reset_data_payload(grad.data) # release the memory of param
param.colo_attr.grad_payload_reset(grad.data)
# release the memory of param
# we set a false None for parameter's payload
# so we can get paramter's device and dtype later in optimizer
param.colo_attr.data_payload_reset(torch.empty(0, device=grad.device, dtype=grad.dtype))
if param.colo_attr.is_replicated:
param.colo_attr.sharded_data_tensor.is_sharded = True
@ -368,7 +371,7 @@ class ShardedModelV2(nn.Module):
fp32_grad = cast_tensor_to_fp32(grad)
if param.colo_attr.saved_grad.is_null():
param.colo_attr.reset_grad_payload(fp32_grad)
param.colo_attr.grad_payload_reset(fp32_grad)
else:
param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload))

View File

@ -12,15 +12,15 @@ from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.gemini.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
colo_tensor_mem_usage)
from colossalai.gemini.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
colo_tensor_mem_usage)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from colossalai.gemini.stateful_tensor import (StatefulTensor, TensorState)
from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
@ -253,7 +253,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
for p in group['params']:
# p.colo_attr.sharded_data_tensor stores grad now
# we have to recover fp16 param
reuse_fp16_shard = p.colo_attr.saved_grad.data_ptr() == p.colo_attr.sharded_data_tensor.data_ptr()
reuse_fp16_shard = (p.colo_attr.sharded_data_tensor.payload_size == 0)
if recover_data and reuse_fp16_shard:
self._copy_master_param_to_param_fp16(p)
else:
@ -332,12 +332,23 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
def _copy_master_param_to_param_fp16(self, p):
# flush gradient
p.colo_attr.saved_grad.set_null()
if p.colo_attr.sharded_data_tensor.payload_size == 0:
# here reuse_fp16_shard is True
# in order to use copy below, we should give sharded data tensor a payload
p.colo_attr.sharded_data_tensor.payload_relay(p.colo_attr.saved_grad)
else:
p.colo_attr.saved_grad.set_null()
p.data = self.master_params[p].payload
# we need to allocate new memory for keep_not_shard paramters
# in order to use copy, otherwise, the sizes of tensor is not compatible
if p.colo_attr.data_payload.numel() != p.data.numel():
p.colo_attr.data_payload_reset(
torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device))
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
p.data = self.master_params[p].payload
p.colo_attr.reset_data_payload(
colo_model_tensor_clone(p.half().detach(), p.colo_attr.sharded_data_tensor.device))
p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach())
p.colo_attr.set_data_none()
if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:

View File

@ -1,11 +1,5 @@
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move, colo_model_data_tensor_move_inline,
colo_model_data_move_to_cpu, colo_model_tensor_clone,
colo_tensor_mem_usage)
from colossalai.zero.sharded_param.tensorful_state import TensorState, StatefulTensor
__all__ = [
'ShardedTensor', 'ShardedParamV2', 'colo_model_data_tensor_move', 'colo_model_data_tensor_move_inline',
'colo_model_data_move_to_cpu', 'colo_model_tensor_clone', 'colo_tensor_mem_usage', 'TensorState', 'StatefulTensor'
]
'ShardedTensor', 'ShardedParamV2']

View File

@ -1,8 +1,8 @@
import torch
from colossalai.zero.sharded_param import ShardedTensor
from typing import Optional, Tuple
from colossalai.zero.sharded_param.tensor_utils import colo_tensor_mem_usage
from .tensorful_state import StatefulTensor, TensorState
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.gemini.tensor_utils import colo_tensor_mem_usage
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from typing import List
EMPTY_TENSOR_DICT = {}
@ -50,6 +50,7 @@ class ShardedParamV2(object):
@property
def data_payload(self):
assert not self.sharded_data_tensor.is_null()
return self.sharded_data_tensor.payload
@property
@ -61,15 +62,15 @@ class ShardedParamV2(object):
def param_is_sharded(self):
return self.sharded_data_tensor.is_sharded
def reset_data_payload(self, tensor: torch.Tensor):
def data_payload_reset(self, tensor: torch.Tensor):
assert type(tensor) is torch.Tensor
assert tensor.requires_grad is False
self.sharded_data_tensor.reset_payload(tensor)
self.sharded_data_tensor.payload_reset(tensor)
def reset_grad_payload(self, tensor: torch.Tensor):
def grad_payload_reset(self, tensor: torch.Tensor):
assert type(tensor) is torch.Tensor
assert tensor.requires_grad is False
self.saved_grad.reset_payload(tensor)
self.saved_grad.payload_reset(tensor)
def get_memory_usage(self) -> Tuple[int, int]:
"""

View File

@ -1,6 +1,5 @@
import torch
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
from typing import Optional
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
class ShardedTensor(StatefulTensor):

View File

@ -1,80 +0,0 @@
from enum import Enum
from typing import Optional
import torch
class TensorState(Enum):
    """State labels carried by a ``StatefulTensor``."""
    # FREE: the stateful tensor holds no payload (its payload must be None).
    FREE = 0
    # HOLD: default state at construction, and the state set after a payload reset.
    HOLD = 1
    # NOTE(review): presumably "kept after the forward pass" — confirm with the code that sets it.
    HOLD_AFTER_FWD = 2
    # NOTE(review): presumably "kept after the backward pass" — confirm with the code that sets it.
    HOLD_AFTER_BWD = 3
    # NOTE(review): presumably "currently needed for computation" — confirm with the code that sets it.
    COMPUTE = 4
class StatefulTensor(object):
    """A torch tensor (the payload) paired with a ``TensorState`` label.

    Inspired from the paper:
    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
    https://arxiv.org/abs/2108.05818
    """

    def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
        """Store ``tensor`` under ``state``; a FREE tensor must be given ``None``."""
        self._payload = tensor
        self._state = state
        if state == TensorState.FREE:
            assert tensor is None, f"payload has to None if state is {self._state}"

    def data_ptr(self):
        """Return the payload's data pointer, or None when there is no payload."""
        tensor = self._payload
        return None if tensor is None else tensor.data_ptr()

    @property
    def state(self) -> TensorState:
        """The current state label."""
        return self._state

    def set_null(self) -> None:
        """Drop the payload and mark the tensor FREE."""
        self._payload = None
        self._state = TensorState.FREE

    def is_null(self) -> bool:
        """Return True when the state is FREE (also asserts the payload is None)."""
        if self._state != TensorState.FREE:
            return False
        assert self._payload is None
        return True

    def trans_state(self, state: TensorState) -> None:
        """Switch to ``state``; switching to FREE also drops the payload."""
        if state == TensorState.FREE:
            self._payload = None
        self._state = state

    @property
    def payload(self) -> Optional[torch.Tensor]:
        """The wrapped tensor, or None."""
        return self._payload

    def copy_payload(self, tensor) -> None:
        """Copy ``tensor`` element-wise into the current payload (both viewed as 1-D)."""
        flat_dst = self._payload.view(-1)
        flat_dst.copy_(tensor.view(-1))

    def reset_payload(self, tensor) -> None:
        """Replace the payload with ``tensor`` and switch the state to HOLD."""
        del self._payload
        self._payload = tensor
        self.trans_state(TensorState.HOLD)

    @property
    def device(self) -> torch.device:
        """Device of the payload."""
        return self._payload.device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the payload."""
        return self._payload.dtype

    @property
    def shape(self):
        """Shape of the payload."""
        return self._payload.shape

    def to(self, device: torch.device):
        """Always raises: direct device moves are not allowed on this wrapper."""
        raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")

    def to_(self, device: torch.device):
        """Always raises: direct device moves are not allowed on this wrapper."""
        raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")

View File

@ -8,12 +8,11 @@ from colossalai.registry import OPHOOKS
from colossalai.utils import get_current_device
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.tensorful_state import TensorState
from colossalai.engine.ophooks import BaseOpHook
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.gemini.memory_tracer import MemStatsCollector
from typing import Any
from colossalai.gemini.stateful_tensor import TensorState
@OPHOOKS.register_module

View File

@ -0,0 +1,73 @@
import pytest
import torch
from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor
@pytest.mark.dist
def test_gemini_manager():
    """Check that StatefulTensor.GST_MGR tracks bytes per device type and state.

    The test walks through four stages (construction, payload reset, device
    move, state transition) and compares the manager's tables with
    hand-computed byte counts after each one. It allocates CUDA tensors.
    """
    # reset the manager, in case that there exists memory information left
    manager = StatefulTensor.GST_MGR
    manager.reset()

    # --- stage 1: construction ---
    # occupation 8 (2 * 2 elements * 2 bytes)
    st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda'))
    # occupation 60 (3 * 5 elements * 4 bytes)
    st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu'))

    # occupation 28 (7 elements * 4 bytes)
    t1 = torch.empty(7, device='cuda')
    # occupation 12 (3 elements * 4 bytes)
    t2 = torch.empty(3, device='cpu')
    st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD)
    # a FREE stateful tensor is counted as an instance but adds no memory
    st4 = StatefulTensor(None, TensorState.FREE)

    assert manager.total_number == 4
    assert manager.total_mem['cpu'] == 60
    # cuda: st1 (8) + st3 (28)
    assert manager.total_mem['cuda'] == 36
    assert manager.state_mem['cpu'][TensorState.HOLD] == 60
    assert manager.state_mem['cuda'][TensorState.HOLD] == 8
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28

    # --- stage 2: payload reset ---
    # st4 had no payload, so it becomes HOLD with 12 bytes on cpu
    st4.payload_reset(t2)
    # st3 keeps HOLD_AFTER_FWD: its 28 cuda bytes are removed, 12 cpu bytes added
    # (st3 and st4 now wrap the same tensor t2; each one is accounted separately)
    st3.payload_reset(t2)

    # resetting payloads does not create instances
    assert manager.total_number == 4
    # cpu: st2 (60) + st4 (12) + st3 (12)
    assert manager.total_mem['cpu'] == 84
    assert manager.total_mem['cuda'] == 8
    # cpu HOLD: st2 (60) + st4 (12)
    assert manager.state_mem['cpu'][TensorState.HOLD] == 72
    assert manager.state_mem['cuda'][TensorState.HOLD] == 8
    assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0

    # --- stage 3: device move ---
    # st1: 8 bytes go from cuda to cpu
    st1.move_to(torch.device('cpu'))
    # st2 is already on cpu, so nothing changes
    st2.move_to(torch.device('cpu'))
    # st3: 12 bytes go from cpu to cuda
    st3.move_to(torch.device('cuda', 0))

    assert manager.total_number == 4
    # cpu: st1 (8) + st2 (60) + st4 (12)
    assert manager.total_mem['cpu'] == 80
    assert manager.total_mem['cuda'] == 12
    assert manager.state_mem['cpu'][TensorState.HOLD] == 80
    assert manager.state_mem['cuda'][TensorState.HOLD] == 0
    assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12

    # --- stage 4: state transition ---
    # st1 (8 bytes, cpu): HOLD -> COMPUTE
    st1.trans_state(TensorState.COMPUTE)
    # st2 (60 bytes, cpu): HOLD -> COMPUTE -> HOLD_AFTER_BWD
    st2.trans_state(TensorState.COMPUTE)
    st2.trans_state(TensorState.HOLD_AFTER_BWD)

    # state transitions move bytes between states, totals stay the same
    assert manager.total_number == 4
    assert manager.total_mem['cpu'] == 80
    assert manager.total_mem['cuda'] == 12
    # only st4 (12) is still recorded as HOLD on cpu
    assert manager.state_mem['cpu'][TensorState.HOLD] == 12
    assert manager.state_mem['cuda'][TensorState.HOLD] == 0
    assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12
    assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60
    assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0
    assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8
    assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0


# allow running this test file directly, without pytest
if __name__ == '__main__':
    test_gemini_manager()

View File

@ -6,9 +6,8 @@ from colossalai.utils.cuda import get_current_device
from colossalai.gemini.memory_tracer import MemStatsCollector
from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory import colo_set_process_memory_fraction
from colossalai.gemini import StatefulTensorMgr
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import TensorState
from colossalai.gemini.stateful_tensor import TensorState
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from torch.nn.parameter import Parameter

View File

@ -1,7 +1,7 @@
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.zero.sharded_param import ShardedTensor
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
import colossalai
import torch

View File

@ -11,7 +11,7 @@ from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardS
from colossalai.zero.sharded_param import ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from tests.test_zero.common import CONFIG, allclose
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
from colossalai.gemini.stateful_tensor import StatefulTensor
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])

View File

@ -2,9 +2,10 @@ import pytest
import colossalai
from colossalai.utils.cuda import get_current_device
from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage, colo_model_data_tensor_move,
colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
colo_model_tensor_clone)
from colossalai.gemini.tensor_utils import (colo_tensor_mem_usage, colo_model_data_tensor_move,
colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
colo_model_tensor_clone)
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use