[gemini] add GeminiMemoryManager (#832)

* refactor StatefulTensor, tensor utilities

* add unit test for GeminiMemoryManager
pull/837/head
HELSON 2022-04-24 13:08:48 +08:00 committed by GitHub
parent 35ea6e1023
commit e5ea3fdeef
23 changed files with 414 additions and 180 deletions

View File

@@ -0,0 +1,45 @@
from enum import EnumMeta
class GeminiMemoryManager(object):
def __init__(self, states_cls: EnumMeta):
super().__init__()
self.states_cls = states_cls
self._cnter = 0 # the counter of instances
self.total_mem = dict()
self.state_mem = dict()
self.state_mem['cpu'] = dict()
self.state_mem['cuda'] = dict()
self.reset()
@property
def total_number(self):
return self._cnter
def reset(self):
self._cnter = 0 # the counter of instances
self.total_mem['cpu'] = 0 # memory occupation of instances in cpu
self.total_mem['cuda'] = 0 # memory occupation of instances in cuda
# memory conditions for all states
for state in self.states_cls:
self.state_mem['cpu'][state] = 0
self.state_mem['cuda'][state] = 0
def register_new_instance(self):
self._cnter += 1
def print_info(self):
print(
f"Total number: {self.total_number}",
f"Total CPU memory occupation: {self.total_mem['cpu']}",
f"Total CUDA memory occupation: {self.total_mem['cuda']}\n", sep='\n')
for state in self.states_cls:
print(
f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n", sep='\n')

View File

@@ -0,0 +1,204 @@
from enum import Enum
from typing import Optional
import torch
from typing import Union
from colossalai.gemini.gemini_context import GeminiMemoryManager
def sizeof_tensor(tensor: torch.Tensor):
return tensor.numel() * tensor.element_size()
class TensorState(Enum):
FREE = 0
HOLD = 1
HOLD_AFTER_FWD = 2
HOLD_AFTER_BWD = 3
COMPUTE = 4
class StatefulTensor(object):
"""A Structure stores a Torch Tensor and labeled states.
Inspired from the paper:
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
https://arxiv.org/abs/2108.05818
"""
# Global Stateful Tensor Manager
GST_MGR = GeminiMemoryManager(TensorState)
def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
self._state = state
self._payload = None
self._payload_size = 0 # byte size of current payload
StatefulTensor.GST_MGR.register_new_instance()
if self._state == TensorState.FREE:
# when the state is free, payload should be None
assert maybe_tensor is None, f"payload has to be None if state is {self._state}"
else:
# otherwise, payload should not be None
assert maybe_tensor is not None, f"payload can't be None if state is {self._state}"
self._payload = maybe_tensor
self._payload_size = sizeof_tensor(maybe_tensor)
self.__trans_state_update(TensorState.FREE, state)
def data_ptr(self):
if self._payload is None:
return 0 # if a tensor has no storage, 0 should be returned
return self._payload.data_ptr()
def set_null(self) -> None:
# note that a free stateful tensor does not need to be set to null again
if self.state != TensorState.FREE:
self.__trans_state_update(self.state, TensorState.FREE)
self.__release()
def is_null(self) -> bool:
if self.state == TensorState.FREE:
# check sanity here
assert self.payload is None
return True
return False
def trans_state(self, state: TensorState) -> None:
if self.state == TensorState.FREE:
# a free stateful tensor can't change its state
assert state == TensorState.FREE, "Free stateful tensor can't change to other states"
return
self.__trans_state_update(self.state, state)
if state == TensorState.FREE:
self.__release()
else:
self._state = state
def move_to(self, device: Union[torch.device, int]):
assert self.state is not TensorState.FREE, "Can't move free stateful tensor"
if not isinstance(device, torch.device):
to_device = torch.device('cuda', device)
else:
to_device = device
from_device_type = self.device.type
if from_device_type == to_device.type:
# from device == to device
return
# update manager's information
self.__trans_device_update(from_device_type, to_device.type)
self.payload.data = self.payload.data.to(to_device)
def payload_copy(self, tensor) -> None:
self._payload.view(-1).copy_(tensor.view(-1))
def payload_reset(self, tensor) -> None:
assert tensor is not None, "Can't reset the payload to None for a stateful tensor, please use set_null() instead"
if self.payload is not None:
# release old payload
self.__trans_state_update(self.state, TensorState.FREE)
else:
# otherwise, set the state to HOLD for new payload
self._state = TensorState.HOLD
del self._payload
self._payload = tensor
self._payload_size = sizeof_tensor(tensor)
# record new payload
self.__trans_state_update(TensorState.FREE, self.state)
def payload_relay(self, rhs):
# relay the payload of rhs to the current stateful tensor
# relaying a null payload is not supported yet
assert not rhs.is_null()
# for now this function only supports stateful tensors whose payload has zero length
# because that case does not require updating the memory manager
# you can extend this function yourself if needed
assert self.payload_size == 0
self._payload = rhs.payload
self._payload_size = rhs.payload_size
self._state = TensorState.HOLD
self.__trans_state_update(rhs.state, TensorState.HOLD)
rhs.__release()
@property
def payload(self) -> Optional[torch.Tensor]:
return self._payload
@property
def payload_size(self) -> int:
return self._payload_size
@property
def state(self) -> TensorState:
return self._state
@property
def device(self) -> torch.device:
return self._payload.device
@property
def dtype(self) -> torch.dtype:
return self._payload.dtype
@property
def shape(self):
return self._payload.shape
def to(self, device: torch.device):
raise RuntimeError("Use move_to(...) instead of call .to() on StatefulTensor")
def to_(self, device: torch.device):
raise RuntimeError("Use move_to(...) instead of call .to_() on StatefulTensor")
def __release(self):
# release current payload
# shouldn't be visible to users
self._state = TensorState.FREE
self._payload = None
self._payload_size = 0
def __trans_state_update(self, from_state: TensorState, to_state: TensorState):
"""Update global manager when changing the state of a tensor
"""
manager = StatefulTensor.GST_MGR
size = self.payload_size
device_type = self.device.type
if from_state != TensorState.FREE:
manager.state_mem[device_type][from_state] -= size
else:
# when from_state is FREE, the tensor is new to the manager
# so we should add its memory
manager.total_mem[device_type] += size
if to_state != TensorState.FREE:
manager.state_mem[device_type][to_state] += size
else:
# when to_state is FREE, the tensor will be deleted soon
# so we should subtract its memory
manager.total_mem[device_type] -= size
def __trans_device_update(self, from_type: str, to_type: str):
"""Update global manager when changing the device of a tensor
"""
manager = StatefulTensor.GST_MGR
size = self.payload_size
state = self.state
# update aggregated information
manager.total_mem[from_type] -= size
manager.total_mem[to_type] += size
# update the information of each state
manager.state_mem[from_type][state] -= size
manager.state_mem[to_type][state] += size
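Every state or device transition above funnels through __trans_state_update / __trans_device_update, so GST_MGR always reflects the live payload bytes. A CPU-only sketch of the lifecycle, using only the API defined in this file (the global manager is reset first, as the new unit test also does):

import torch

from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState

mgr = StatefulTensor.GST_MGR
mgr.reset()

st = StatefulTensor(torch.empty(3, dtype=torch.float32, device='cpu'))  # 12 bytes, state HOLD
assert mgr.total_mem['cpu'] == 12
assert mgr.state_mem['cpu'][TensorState.HOLD] == 12

st.trans_state(TensorState.COMPUTE)    # bytes move between state buckets
assert mgr.state_mem['cpu'][TensorState.COMPUTE] == 12

st.set_null()                          # FREE: payload dropped, bytes subtracted
assert st.is_null() and mgr.total_mem['cpu'] == 0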

View File

@@ -2,9 +2,8 @@ import functools
import torch
import types
from colossalai.utils.cuda import get_current_device
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
from typing import List
from colossalai.logging import get_dist_logger
@@ -30,7 +29,8 @@ class StatefulTensorMgr(object):
self._cpu_gpu_move_volume = 0
def register_stateful_param(self, param: ShardedParamV2) -> None:
def register_stateful_param(self, param) -> None:
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
assert isinstance(param, ShardedParamV2)
for t in param.get_payload_tensors():
assert isinstance(t, StatefulTensor)

View File

@@ -4,8 +4,8 @@ import torch
from colossalai.utils import get_current_device
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.gemini.memory_tracer import MemStatsCollector
from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
from typing import Type

View File

@@ -1,10 +1,10 @@
import torch
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
from colossalai.gemini.stateful_tensor import StatefulTensor
from typing import Union, Tuple
def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
if issubclass(type(tensor), StatefulTensor):
if isinstance(tensor, StatefulTensor):
t = tensor.payload
elif isinstance(tensor, torch.Tensor):
t = tensor
@@ -32,15 +32,16 @@ def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_
The function will record the communication volume between CPU and GPU.
Args:
t_src (Union[StatefulTensor, torch.Tensor]): source tensor
src_t (Union[StatefulTensor, torch.Tensor]): source tensor
tgt_t (Union[StatefulTensor, torch.Tensor]): target tensor
"""
if issubclass(type(src_t), StatefulTensor):
if isinstance(src_t, StatefulTensor):
src_t_payload = src_t.payload
else:
src_t_payload = src_t.data
src_dev = src_t_payload.device
if issubclass(type(tgt_t), StatefulTensor):
if isinstance(tgt_t, StatefulTensor):
tgt_t_payload = tgt_t.payload
else:
tgt_t_payload = tgt_t.data
@@ -48,10 +49,10 @@ def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_
tgt_t_payload.copy_(src_t_payload)
# remove payload of src_t
if issubclass(type(src_t), StatefulTensor):
src_t.reset_payload(torch.tensor([], device=src_dev, dtype=src_t_payload.dtype))
if isinstance(src_t, StatefulTensor):
src_t.set_null()
else:
src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)
src_t.data = torch.empty(0, device=src_dev, dtype=src_t_payload.dtype)
def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device,
@@ -62,56 +63,42 @@ def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], t
t (Union[StatefulTensor, torch.Tensor]): the tensor to be moved
target_device: the target device; if it is an int, it is treated as a CUDA device index.
"""
if isinstance(t, torch.Tensor):
t_payload = t
elif issubclass(type(t), StatefulTensor):
t_payload = t.payload
else:
raise TypeError('colo_model_data_move_to_cpu dose not accept type {type(t)}')
if not isinstance(target_device, torch.device):
target_device = torch.device(f'cuda:{target_device}')
# deal with torch.device('cpu') and torch.device('cpu:0)
if t_payload.device.type == target_device.type:
return
t_payload.data = t_payload.data.to(target_device)
if isinstance(t, torch.Tensor):
t.data = t.data.to(target_device)
elif isinstance(t, StatefulTensor):
t.move_to(target_device)
else:
raise TypeError(f'colo_model_data_tensor_move_inline does not accept type {type(t)}')
def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None:
"""colo_model_data_move_to_cpu
move a model data tensor from gpu to cpu
Args:
t (Union[StatefulTensor, torch.Tensor]): the tensor to be moved to CPU
"""
if issubclass(type(t), StatefulTensor):
t_payload = t.payload
elif isinstance(t, torch.Tensor):
t_payload = t
else:
raise TypeError('colo_model_data_move_to_cpu dose not accept type {type(t)}')
if t_payload.device.type == 'cpu':
return
# TODO() optimize the tensor moving with non-blocking
t_payload.data = t_payload.data.cpu()
if isinstance(t, torch.Tensor):
t.data = t.data.cpu()
elif isinstance(t, StatefulTensor):
t.move_to(torch.device('cpu'))
else:
raise TypeError(f'colo_model_data_move_to_cpu does not accept type {type(t)}')
def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
"""
Clone a model data tensor
Args:
t (Union[StatefulTensor, torch.Tensor]): a model data tensor
target_device (torch.device): the target device
Returns:
torch.Tensor: a cloned torch tensor
"""
t_payload = t.payload if issubclass(type(t), StatefulTensor) else t
ret = t_payload.to(target_device)
return ret
# TODO() rename this function
colo_model_data_tensor_move_inline(t, target_device)
t_payload = t.payload if isinstance(t, StatefulTensor) else t
return t_payload
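A small sketch of the new dispatch in tensor_utils (CPU-only so it runs anywhere; as the hunks above show, colo_model_data_tensor_move copies the payload and then nulls the source):

import torch

from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.gemini.tensor_utils import (colo_model_data_tensor_move, colo_model_data_move_to_cpu)

src = StatefulTensor(torch.ones(4))
tgt = StatefulTensor(torch.zeros(4))
colo_model_data_tensor_move(src, tgt)    # copy payload into tgt, then set src to null
assert src.is_null() and tgt.payload.sum() == 4

t = torch.ones(2)
colo_model_data_move_to_cpu(t)           # plain tensors go through t.data = t.data.cpu()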

View File

@@ -8,7 +8,7 @@ from .experts import FFNExperts, TPExperts
class ForceFP32Parameter(torch.nn.Parameter):
def half(self, memory_format=None):
return self.data
return self.data.clone()
class NormalNoiseGenerator:
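The .clone() above matters because half() used to hand out the parameter's own fp32 storage, so in-place edits on the "half" view could corrupt the master weight. A standalone illustration (class body copied from this hunk):

import torch


class ForceFP32Parameter(torch.nn.Parameter):

    def half(self, memory_format=None):
        return self.data.clone()    # a copy; the dtype intentionally stays fp32


p = ForceFP32Parameter(torch.ones(2))
h = p.half()
h.zero_()      # mutate the returned tensor in place
print(p)       # still ones: the fp32 master weight is untouched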

View File

@@ -35,4 +35,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model
return zero_model, zero_optimizer
__all__ = ['convert_to_zerov2', 'ShardedModelV2', 'ShardedOptimizerV2']
__all__ = ['convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2']

View File

@@ -184,11 +184,12 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
if param.grad is not None:
param.grad = param.grad.to(target_device)
param.colo_attr = ShardedParamV2(param, set_data_none=False)
param.colo_attr = ShardedParamV2(param, set_data_none=True)
if self.shard_param:
self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
param.data = param.colo_attr.data_payload # set param.data to payload
param.data = param.colo_attr.data_payload # set param.data to payload
# mark whether the param is replicated
param.colo_attr.is_replicated = self.is_replicated

View File

@@ -31,9 +31,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
for i in range(world_size):
if i == rank:
buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
# Release payload here, to decrease peak memory usage
for t in tensor_list:
t.reset_payload(None)
else:
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
@@ -44,6 +41,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
for i, t in enumerate(tensor_list):
gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list]
gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape)
t.reset_payload(gathered_payload)
t.payload_reset(gathered_payload)
t.is_sharded = False
offset += tensor_numels[i]

View File

@@ -3,10 +3,10 @@ from typing import List, Optional
import torch
import torch.distributed as dist
from colossalai.utils import get_current_device
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.shard_utils.commons import get_shard
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline
class TensorShardStrategy(BaseShardStrategy):
@@ -36,7 +36,7 @@ class TensorShardStrategy(BaseShardStrategy):
assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
f" but current cuda device is {get_current_device()}"
sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
t.reset_payload(sharded_payload)
t.payload_reset(sharded_payload)
t.is_sharded = True
def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
@@ -53,6 +53,6 @@ class TensorShardStrategy(BaseShardStrategy):
dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape)
t.reset_payload(gathered_payload)
t.payload_reset(gathered_payload)
colo_model_data_tensor_move_inline(t, target_device)
t.is_sharded = False

View File

@@ -3,7 +3,7 @@ from typing import Any, Callable, List, Tuple
import torch
import torch.nn.functional as F
from typing import Union
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
from colossalai.gemini.stateful_tensor import StatefulTensor
def get_gradient_predivide_factor(world_size: int) -> float:

View File

@@ -17,11 +17,11 @@ from colossalai.gemini.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory import colo_device_memory_capacity
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_move_to_cpu
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_param.tensorful_state import TensorState
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu
from colossalai.gemini.stateful_tensor import TensorState
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
@@ -358,8 +358,11 @@ class ShardedModelV2(nn.Module):
assert param.colo_attr.saved_grad.is_null(
), 'Gradient accumulation is not supported when reuse_fp16_shard=True'
param.colo_attr.reset_grad_payload(grad.data)
param.colo_attr.reset_data_payload(grad.data) # release the memory of param
param.colo_attr.grad_payload_reset(grad.data)
# release the memory of param
# we give the parameter an empty payload as a stand-in for None
# so we can still get the parameter's device and dtype later in the optimizer
param.colo_attr.data_payload_reset(torch.empty(0, device=grad.device, dtype=grad.dtype))
if param.colo_attr.is_replicated:
param.colo_attr.sharded_data_tensor.is_sharded = True
@@ -368,7 +371,7 @@ class ShardedModelV2(nn.Module):
fp32_grad = cast_tensor_to_fp32(grad)
if param.colo_attr.saved_grad.is_null():
param.colo_attr.reset_grad_payload(fp32_grad)
param.colo_attr.grad_payload_reset(fp32_grad)
else:
param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload))
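The empty tensor written by data_payload_reset is deliberate: it keeps the parameter's device and dtype readable while freeing the bytes, and payload_size == 0 is exactly what ShardedOptimizerV2 later reads as the reuse_fp16_shard flag. A sketch of the signal on its own:

import torch

from colossalai.gemini.stateful_tensor import StatefulTensor

st = StatefulTensor(torch.ones(4, dtype=torch.float16))
st.payload_reset(torch.empty(0, dtype=torch.float16))    # keep device/dtype, drop the bytes
assert st.payload_size == 0         # read back later as "the fp16 shard was reused"
assert st.dtype == torch.float16    # metadata survives for the optimizer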

View File

@@ -12,15 +12,15 @@ from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.gemini.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
colo_tensor_mem_usage)
from colossalai.gemini.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
colo_tensor_mem_usage)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer
from colossalai.gemini.stateful_tensor import (StatefulTensor, TensorState)
from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
@@ -253,7 +253,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
for p in group['params']:
# p.colo_attr.sharded_data_tensor stores grad now
# we have to recover fp16 param
reuse_fp16_shard = p.colo_attr.saved_grad.data_ptr() == p.colo_attr.sharded_data_tensor.data_ptr()
reuse_fp16_shard = (p.colo_attr.sharded_data_tensor.payload_size == 0)
if recover_data and reuse_fp16_shard:
self._copy_master_param_to_param_fp16(p)
else:
@@ -332,12 +332,23 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
def _copy_master_param_to_param_fp16(self, p):
# flush gradient
p.colo_attr.saved_grad.set_null()
if p.colo_attr.sharded_data_tensor.payload_size == 0:
# reuse_fp16_shard is True in this branch
# to use the copy below, we must first give the sharded data tensor a payload
p.colo_attr.sharded_data_tensor.payload_relay(p.colo_attr.saved_grad)
else:
p.colo_attr.saved_grad.set_null()
p.data = self.master_params[p].payload
# we need to allocate new memory for keep_not_shard parameters
# in order to use copy; otherwise the tensor sizes are not compatible
if p.colo_attr.data_payload.numel() != p.data.numel():
p.colo_attr.data_payload_reset(
torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device))
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
p.data = self.master_params[p].payload
p.colo_attr.reset_data_payload(
colo_model_tensor_clone(p.half().detach(), p.colo_attr.sharded_data_tensor.device))
p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach())
p.colo_attr.set_data_none()
if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
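payload_relay is what makes the branch above work: the grad's storage is handed over to the zero-length data tensor and the source is released, so the manager's totals stay consistent (the bytes merely change owner). A standalone sketch:

import torch

from colossalai.gemini.stateful_tensor import StatefulTensor

grad = StatefulTensor(torch.ones(4))    # stands in for saved_grad
data = StatefulTensor(torch.empty(0))   # zero-length payload, as payload_relay requires
data.payload_relay(grad)
assert grad.is_null()                   # the source is released
assert data.payload_size == 16 and data.payload.sum() == 4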

View File

@@ -1,11 +1,5 @@
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move, colo_model_data_tensor_move_inline,
colo_model_data_move_to_cpu, colo_model_tensor_clone,
colo_tensor_mem_usage)
from colossalai.zero.sharded_param.tensorful_state import TensorState, StatefulTensor
__all__ = [
'ShardedTensor', 'ShardedParamV2', 'colo_model_data_tensor_move', 'colo_model_data_tensor_move_inline',
'colo_model_data_move_to_cpu', 'colo_model_tensor_clone', 'colo_tensor_mem_usage', 'TensorState', 'StatefulTensor'
]
'ShardedTensor', 'ShardedParamV2']

View File

@@ -1,8 +1,8 @@
import torch
from colossalai.zero.sharded_param import ShardedTensor
from typing import Optional, Tuple
from colossalai.zero.sharded_param.tensor_utils import colo_tensor_mem_usage
from .tensorful_state import StatefulTensor, TensorState
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.gemini.tensor_utils import colo_tensor_mem_usage
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
from typing import List
EMPTY_TENSOR_DICT = {}
@@ -50,6 +50,7 @@ class ShardedParamV2(object):
@property
def data_payload(self):
assert not self.sharded_data_tensor.is_null()
return self.sharded_data_tensor.payload
@property
@@ -61,15 +62,15 @@
def param_is_sharded(self):
return self.sharded_data_tensor.is_sharded
def reset_data_payload(self, tensor: torch.Tensor):
def data_payload_reset(self, tensor: torch.Tensor):
assert type(tensor) is torch.Tensor
assert tensor.requires_grad is False
self.sharded_data_tensor.reset_payload(tensor)
self.sharded_data_tensor.payload_reset(tensor)
def reset_grad_payload(self, tensor: torch.Tensor):
def grad_payload_reset(self, tensor: torch.Tensor):
assert type(tensor) is torch.Tensor
assert tensor.requires_grad is False
self.saved_grad.reset_payload(tensor)
self.saved_grad.payload_reset(tensor)
def get_memory_usage(self) -> Tuple[int, int]:
"""

View File

@@ -1,6 +1,5 @@
import torch
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
from typing import Optional
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
class ShardedTensor(StatefulTensor):

View File

@@ -1,80 +0,0 @@
from enum import Enum
from typing import Optional
import torch
class TensorState(Enum):
FREE = 0
HOLD = 1
HOLD_AFTER_FWD = 2
HOLD_AFTER_BWD = 3
COMPUTE = 4
class StatefulTensor(object):
"""A Structure stores a Torch Tensor and labeled states.
Inspired from the paper:
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
https://arxiv.org/abs/2108.05818
"""
def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
self._state = state
self._payload = tensor
if self._state == TensorState.FREE:
assert self._payload is None, f"payload has to None if state is {self._state}"
def data_ptr(self):
if self._payload is None:
return None
return self._payload.data_ptr()
@property
def state(self) -> TensorState:
return self._state
def set_null(self) -> None:
self._state = TensorState.FREE
self._payload = None
def is_null(self) -> bool:
if self._state == TensorState.FREE:
assert self._payload is None
return True
return False
def trans_state(self, state: TensorState) -> None:
self._state = state
if state == TensorState.FREE:
self._payload = None
@property
def payload(self) -> Optional[torch.Tensor]:
return self._payload
def copy_payload(self, tensor) -> None:
self._payload.view(-1).copy_(tensor.view(-1))
def reset_payload(self, tensor) -> None:
del self._payload
self._payload = tensor
self.trans_state(TensorState.HOLD)
@property
def device(self) -> torch.device:
return self._payload.device
@property
def dtype(self) -> torch.dtype:
return self._payload.dtype
@property
def shape(self):
return self._payload.shape
def to(self, device: torch.device):
raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
def to_(self, device: torch.device):
raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")

View File

@@ -8,12 +8,11 @@ from colossalai.registry import OPHOOKS
from colossalai.utils import get_current_device
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.tensorful_state import TensorState
from colossalai.engine.ophooks import BaseOpHook
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
from colossalai.gemini.memory_tracer import MemStatsCollector
from typing import Any
from colossalai.gemini.stateful_tensor import TensorState
@OPHOOKS.register_module

View File

@@ -0,0 +1,73 @@
import pytest
import torch
from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor
@pytest.mark.dist
def test_gemini_manager():
# reset the manager in case stale memory information is left over
manager = StatefulTensor.GST_MGR
manager.reset()
# occupation: 2 * 2 elements * 2 bytes = 8 bytes
st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda'))
# occupation: 3 * 5 elements * 4 bytes = 60 bytes
st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu'))
# occupation: 7 elements * 4 bytes = 28 bytes
t1 = torch.empty(7, device='cuda')
# occupation: 3 elements * 4 bytes = 12 bytes
t2 = torch.empty(3, device='cpu')
st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD)
st4 = StatefulTensor(None, TensorState.FREE)
assert manager.total_number == 4
assert manager.total_mem['cpu'] == 60
assert manager.total_mem['cuda'] == 36
assert manager.state_mem['cpu'][TensorState.HOLD] == 60
assert manager.state_mem['cuda'][TensorState.HOLD] == 8
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28
st4.payload_reset(t2)
st3.payload_reset(t2)
assert manager.total_number == 4
assert manager.total_mem['cpu'] == 84
assert manager.total_mem['cuda'] == 8
assert manager.state_mem['cpu'][TensorState.HOLD] == 72
assert manager.state_mem['cuda'][TensorState.HOLD] == 8
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0
st1.move_to(torch.device('cpu'))
st2.move_to(torch.device('cpu'))
st3.move_to(torch.device('cuda', 0))
assert manager.total_number == 4
assert manager.total_mem['cpu'] == 80
assert manager.total_mem['cuda'] == 12
assert manager.state_mem['cpu'][TensorState.HOLD] == 80
assert manager.state_mem['cuda'][TensorState.HOLD] == 0
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12
st1.trans_state(TensorState.COMPUTE)
st2.trans_state(TensorState.COMPUTE)
st2.trans_state(TensorState.HOLD_AFTER_BWD)
assert manager.total_number == 4
assert manager.total_mem['cpu'] == 80
assert manager.total_mem['cuda'] == 12
assert manager.state_mem['cpu'][TensorState.HOLD] == 12
assert manager.state_mem['cuda'][TensorState.HOLD] == 0
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0
assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8
assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0
if __name__ == '__main__':
test_gemini_manager()

View File

@@ -6,9 +6,8 @@ from colossalai.utils.cuda import get_current_device
from colossalai.gemini.memory_tracer import MemStatsCollector
from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory import colo_set_process_memory_fraction
from colossalai.gemini import StatefulTensorMgr
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.zero.sharded_param.tensorful_state import TensorState
from colossalai.gemini.stateful_tensor import TensorState
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from torch.nn.parameter import Parameter

View File

@@ -1,7 +1,7 @@
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.zero.sharded_param import ShardedTensor
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
import colossalai
import torch

View File

@@ -11,7 +11,7 @@ from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardS
from colossalai.zero.sharded_param import ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from tests.test_zero.common import CONFIG, allclose
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
from colossalai.gemini.stateful_tensor import StatefulTensor
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])

View File

@@ -2,9 +2,10 @@ import pytest
import colossalai
from colossalai.utils.cuda import get_current_device
from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage, colo_model_data_tensor_move,
colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
colo_model_tensor_clone)
from colossalai.gemini.tensor_utils import (colo_tensor_mem_usage, colo_model_data_tensor_move,
colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
colo_model_tensor_clone)
from colossalai.gemini.stateful_tensor import StatefulTensor
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use