mirror of https://github.com/hpcaitech/ColossalAI
[gemini] add GeminiMemoryManger (#832)
* refactor StatefulTensor, tensor utilities * add unitest for GeminiMemoryManagerpull/837/head
parent
35ea6e1023
commit
e5ea3fdeef
|
@ -0,0 +1,45 @@
|
|||
from enum import EnumMeta
|
||||
|
||||
|
||||
class GeminiMemoryManager(object):
|
||||
|
||||
def __init__(self, states_cls: EnumMeta):
|
||||
super().__init__()
|
||||
self.states_cls = states_cls
|
||||
self._cnter = 0 # the counter of instances
|
||||
|
||||
self.total_mem = dict()
|
||||
self.state_mem = dict()
|
||||
self.state_mem['cpu'] = dict()
|
||||
self.state_mem['cuda'] = dict()
|
||||
|
||||
self.reset()
|
||||
|
||||
@property
|
||||
def total_number(self):
|
||||
return self._cnter
|
||||
|
||||
def reset(self):
|
||||
self._cnter = 0 # the counter of instances
|
||||
|
||||
self.total_mem['cpu'] = 0 # memory occupation of instances in cpu
|
||||
self.total_mem['cuda'] = 0 # memory of occupation of instances in cuda
|
||||
|
||||
# memory conditions for all states
|
||||
for state in self.states_cls:
|
||||
self.state_mem['cpu'][state] = 0
|
||||
self.state_mem['cuda'][state] = 0
|
||||
|
||||
def register_new_instance(self):
|
||||
self._cnter += 1
|
||||
|
||||
def print_info(self):
|
||||
print(
|
||||
f"Total number: {self.total_number}",
|
||||
f"Total CPU memory occupation: {self.total_mem['cpu']}",
|
||||
f"Total CUDA memory occupation: {self.total_mem['cuda']}\n", sep='\n')
|
||||
|
||||
for state in self.states_cls:
|
||||
print(
|
||||
f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
|
||||
f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n", sep='\n')
|
|
@ -0,0 +1,204 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
import torch
|
||||
from typing import Union
|
||||
|
||||
from colossalai.gemini.gemini_context import GeminiMemoryManager
|
||||
|
||||
|
||||
def sizeof_tensor(tensor: torch.Tensor):
|
||||
return tensor.numel() * tensor.element_size()
|
||||
|
||||
|
||||
class TensorState(Enum):
|
||||
FREE = 0
|
||||
HOLD = 1
|
||||
HOLD_AFTER_FWD = 2
|
||||
HOLD_AFTER_BWD = 3
|
||||
COMPUTE = 4
|
||||
|
||||
|
||||
class StatefulTensor(object):
|
||||
"""A Structure stores a Torch Tensor and labeled states.
|
||||
Inspired from the paper:
|
||||
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
|
||||
|
||||
https://arxiv.org/abs/2108.05818
|
||||
"""
|
||||
# Global Stateful Tensor Manager
|
||||
GST_MGR = GeminiMemoryManager(TensorState)
|
||||
|
||||
def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
|
||||
self._state = state
|
||||
self._payload = None
|
||||
self._payload_size = 0 # byte size of current payload
|
||||
|
||||
StatefulTensor.GST_MGR.register_new_instance()
|
||||
|
||||
if self._state == TensorState.FREE:
|
||||
# when the state is free, payload should be None
|
||||
assert maybe_tensor is None, f"payload has to None if state is {self._state}"
|
||||
else:
|
||||
# otherwise, payload should not be None
|
||||
assert maybe_tensor is not None, f"payload can't be None if state is {self._state}"
|
||||
self._payload = maybe_tensor
|
||||
self._payload_size = sizeof_tensor(maybe_tensor)
|
||||
self.__trans_state_update(TensorState.FREE, state)
|
||||
|
||||
def data_ptr(self):
|
||||
if self._payload is None:
|
||||
return 0 # if a tensor has no storage, 0 should be returned
|
||||
return self._payload.data_ptr()
|
||||
|
||||
def set_null(self) -> None:
|
||||
# notice that free stateful tensor do not need to become null again
|
||||
if self.state != TensorState.FREE:
|
||||
self.__trans_state_update(self.state, TensorState.FREE)
|
||||
self.__release()
|
||||
|
||||
def is_null(self) -> bool:
|
||||
if self.state == TensorState.FREE:
|
||||
# check sanity here
|
||||
assert self.payload is None
|
||||
return True
|
||||
return False
|
||||
|
||||
def trans_state(self, state: TensorState) -> None:
|
||||
if self.state == TensorState.FREE:
|
||||
# free stateful tensor can't change state
|
||||
assert state == TensorState.FREE, "Free stateful tensor can't change to other states"
|
||||
return
|
||||
|
||||
self.__trans_state_update(self.state, state)
|
||||
|
||||
if state == TensorState.FREE:
|
||||
self.__release()
|
||||
else:
|
||||
self._state = state
|
||||
|
||||
def move_to(self, device: Union[torch.device, int]):
|
||||
assert self.state is not TensorState.FREE, "Can't move free stateful tensor"
|
||||
|
||||
if not isinstance(device, torch.device):
|
||||
to_device = torch.device('cuda', device)
|
||||
else:
|
||||
to_device = device
|
||||
|
||||
from_device_type = self.device.type
|
||||
if from_device_type == to_device.type:
|
||||
# from device == to device
|
||||
return
|
||||
|
||||
# update manager's information
|
||||
self.__trans_device_update(from_device_type, to_device.type)
|
||||
self.payload.data = self.payload.data.to(to_device)
|
||||
|
||||
def payload_copy(self, tensor) -> None:
|
||||
self._payload.view(-1).copy_(tensor.view(-1))
|
||||
|
||||
def payload_reset(self, tensor) -> None:
|
||||
|
||||
assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead"
|
||||
|
||||
if self.payload is not None:
|
||||
# release old payload
|
||||
self.__trans_state_update(self.state, TensorState.FREE)
|
||||
else:
|
||||
# otherwise, set the state to HOLD for new payload
|
||||
self._state = TensorState.HOLD
|
||||
del self._payload
|
||||
|
||||
self._payload = tensor
|
||||
self._payload_size = sizeof_tensor(tensor)
|
||||
# record new payload
|
||||
self.__trans_state_update(TensorState.FREE, self.state)
|
||||
|
||||
def payload_relay(self, rhs):
|
||||
# relay the payload of rhs to current stateful tensor
|
||||
# can't support null relay right now
|
||||
assert not rhs.is_null()
|
||||
|
||||
# now this function only support stateful tensor that has zero-length payload
|
||||
# because it doesn't require memory manager updating
|
||||
# you can extend this function by yourself
|
||||
assert self.payload_size == 0
|
||||
|
||||
self._payload = rhs.payload
|
||||
self._payload_size = rhs.payload_size
|
||||
self._state = TensorState.HOLD
|
||||
self.__trans_state_update(rhs.state, TensorState.HOLD)
|
||||
|
||||
rhs.__release()
|
||||
|
||||
@property
|
||||
def payload(self) -> Optional[torch.Tensor]:
|
||||
return self._payload
|
||||
|
||||
@property
|
||||
def payload_size(self) -> int:
|
||||
return self._payload_size
|
||||
|
||||
@property
|
||||
def state(self) -> TensorState:
|
||||
return self._state
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self._payload.device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return self._payload.dtype
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._payload.shape
|
||||
|
||||
def to(self, device: torch.device):
|
||||
raise RuntimeError("Use move_to(...) instead of call .to() on StatefulTensor")
|
||||
|
||||
def to_(self, device: torch.device):
|
||||
raise RuntimeError("Use move_to(...) instead of call .to_() on StatefulTensor")
|
||||
|
||||
def __release(self):
|
||||
# release current payload
|
||||
# shouldn't be visible to users
|
||||
self._state = TensorState.FREE
|
||||
self._payload = None
|
||||
self._payload_size = 0
|
||||
|
||||
def __trans_state_update(self, from_state: TensorState, to_state: TensorState):
|
||||
"""Update global manager when changing the state of a tensor
|
||||
"""
|
||||
manager = StatefulTensor.GST_MGR
|
||||
size = self.payload_size
|
||||
device_type = self.device.type
|
||||
|
||||
if from_state != TensorState.FREE:
|
||||
manager.state_mem[device_type][from_state] -= size
|
||||
else:
|
||||
# when from_state is FREE, the tensor is new to manager
|
||||
# we should add its memory
|
||||
manager.total_mem[device_type] += size
|
||||
|
||||
if to_state != TensorState.FREE:
|
||||
manager.state_mem[device_type][to_state] += size
|
||||
else:
|
||||
# when to_state is FREE, the tensor will be deleted soon
|
||||
# we should sub its memory
|
||||
manager.total_mem[device_type] -= size
|
||||
|
||||
def __trans_device_update(self, from_type: str, to_type: str):
|
||||
"""Update global manager when changing the device of a tensor
|
||||
"""
|
||||
manager = StatefulTensor.GST_MGR
|
||||
size = self.payload_size
|
||||
state = self.state
|
||||
|
||||
# update aggregated information
|
||||
manager.total_mem[from_type] -= size
|
||||
manager.total_mem[to_type] += size
|
||||
|
||||
# update the information of each state
|
||||
manager.state_mem[from_type][state] -= size
|
||||
manager.state_mem[to_type][state] += size
|
|
@ -2,9 +2,8 @@ import functools
|
|||
import torch
|
||||
import types
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
|
||||
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
|
||||
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
|
||||
from typing import List
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
@ -30,7 +29,8 @@ class StatefulTensorMgr(object):
|
|||
|
||||
self._cpu_gpu_move_volume = 0
|
||||
|
||||
def register_stateful_param(self, param: ShardedParamV2) -> None:
|
||||
def register_stateful_param(self, param) -> None:
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
assert isinstance(param, ShardedParamV2)
|
||||
for t in param.get_payload_tensors():
|
||||
assert isinstance(t, StatefulTensor)
|
||||
|
|
|
@ -4,8 +4,8 @@ import torch
|
|||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.memory import colo_device_memory_capacity
|
||||
|
||||
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
|
||||
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||
from colossalai.gemini.memory_tracer import MemStatsCollector
|
||||
from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
|
||||
from typing import Type
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
import torch
|
||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||
from typing import Union, Tuple
|
||||
|
||||
|
||||
def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
|
||||
if issubclass(type(tensor), StatefulTensor):
|
||||
if isinstance(tensor, StatefulTensor):
|
||||
t = tensor.payload
|
||||
elif isinstance(tensor, torch.Tensor):
|
||||
t = tensor
|
||||
|
@ -24,23 +24,24 @@ def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[
|
|||
|
||||
def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor,
|
||||
torch.Tensor]) -> None:
|
||||
"""
|
||||
A colossal API for model data tensor move.
|
||||
"""
|
||||
A colossal API for model data tensor move.
|
||||
The src and target tensors could be resident on both CPU and GPU.
|
||||
|
||||
|
||||
NOTE() The source tensor payload will be removed after this function.
|
||||
|
||||
|
||||
The function will record the communication volume between CPU and GPU.
|
||||
Args:
|
||||
t_src (Union[StatefulTensor, torch.Tensor]): source tensor
|
||||
src_t (Union[StatefulTensor, torch.Tensor]): source tensor
|
||||
tgt_t (Union[StatefulTensor, torch.Tensor]): target tensor
|
||||
"""
|
||||
if issubclass(type(src_t), StatefulTensor):
|
||||
if isinstance(src_t, StatefulTensor):
|
||||
src_t_payload = src_t.payload
|
||||
else:
|
||||
src_t_payload = src_t.data
|
||||
src_dev = src_t_payload.device
|
||||
if issubclass(type(tgt_t), StatefulTensor):
|
||||
|
||||
if isinstance(tgt_t, StatefulTensor):
|
||||
tgt_t_payload = tgt_t.payload
|
||||
else:
|
||||
tgt_t_payload = tgt_t.data
|
||||
|
@ -48,70 +49,56 @@ def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_
|
|||
tgt_t_payload.copy_(src_t_payload)
|
||||
|
||||
# remove payload of src_t
|
||||
if issubclass(type(src_t), StatefulTensor):
|
||||
src_t.reset_payload(torch.tensor([], device=src_dev, dtype=src_t_payload.dtype))
|
||||
if isinstance(src_t, StatefulTensor):
|
||||
src_t.set_null()
|
||||
else:
|
||||
src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)
|
||||
src_t.data = torch.empty(0, device=src_dev, dtype=src_t_payload.dtype)
|
||||
|
||||
|
||||
def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device,
|
||||
int]) -> None:
|
||||
"""
|
||||
"""
|
||||
move a tensor to the target_device
|
||||
Args:
|
||||
t (Union[StatefulTensor, torch.Tensor]): the tensor be moved
|
||||
target_device: a traget device, if type is int, it the index of cuda card.
|
||||
"""
|
||||
if isinstance(t, torch.Tensor):
|
||||
t_payload = t
|
||||
elif issubclass(type(t), StatefulTensor):
|
||||
t_payload = t.payload
|
||||
else:
|
||||
raise TypeError('colo_model_data_move_to_cpu dose not accept type {type(t)}')
|
||||
|
||||
if not isinstance(target_device, torch.device):
|
||||
target_device = torch.device(f'cuda:{target_device}')
|
||||
|
||||
# deal with torch.device('cpu') and torch.device('cpu:0)
|
||||
if t_payload.device.type == target_device.type:
|
||||
return
|
||||
t_payload.data = t_payload.data.to(target_device)
|
||||
if isinstance(t, torch.Tensor):
|
||||
t.data = t.data.to(target_device)
|
||||
elif isinstance(t, StatefulTensor):
|
||||
t.move_to(target_device)
|
||||
else:
|
||||
raise TypeError(f'colo_model_data_tensor_move_inline dose not accept type {type(t)}')
|
||||
|
||||
|
||||
def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None:
|
||||
"""colo_model_data_move_to_cpu
|
||||
|
||||
"""colo_model_data_move_to_cpu
|
||||
move a model data tensor from gpu to cpu
|
||||
|
||||
Args:
|
||||
t (Union[StatefulTensor, torch.Tensor]): _description_
|
||||
"""
|
||||
|
||||
if issubclass(type(t), StatefulTensor):
|
||||
t_payload = t.payload
|
||||
elif isinstance(t, torch.Tensor):
|
||||
t_payload = t
|
||||
else:
|
||||
raise TypeError('colo_model_data_move_to_cpu dose not accept type {type(t)}')
|
||||
|
||||
if t_payload.device.type == 'cpu':
|
||||
return
|
||||
|
||||
# TODO() optimize the tensor moving with non-blocking
|
||||
t_payload.data = t_payload.data.cpu()
|
||||
if isinstance(t, torch.Tensor):
|
||||
t.data = t.data.cpu()
|
||||
elif isinstance(t, StatefulTensor):
|
||||
t.move_to(torch.device('cpu'))
|
||||
else:
|
||||
raise TypeError(f'colo_model_data_move_to_cpu dose not accept type {type(t)}')
|
||||
|
||||
|
||||
def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
|
||||
"""
|
||||
Clone a model data tensor
|
||||
|
||||
Args:
|
||||
t (Union[StatefulTensor, torch.Tensor]): a model data tensor
|
||||
target_device (torch.device): the target device
|
||||
Returns:
|
||||
torch.Tensor: a cloned torch tensor
|
||||
"""
|
||||
t_payload = t.payload if issubclass(type(t), StatefulTensor) else t
|
||||
|
||||
ret = t_payload.to(target_device)
|
||||
return ret
|
||||
# TODO() rename this function
|
||||
colo_model_data_tensor_move_inline(t, target_device)
|
||||
t_payload = t.payload if isinstance(t, StatefulTensor) else t
|
||||
return t_payload
|
|
@ -8,7 +8,7 @@ from .experts import FFNExperts, TPExperts
|
|||
class ForceFP32Parameter(torch.nn.Parameter):
|
||||
|
||||
def half(self, memory_format=None):
|
||||
return self.data
|
||||
return self.data.clone()
|
||||
|
||||
|
||||
class NormalNoiseGenerator:
|
||||
|
|
|
@ -35,4 +35,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model
|
|||
return zero_model, zero_optimizer
|
||||
|
||||
|
||||
__all__ = ['convert_to_zerov2', 'ShardedModelV2', 'ShardedOptimizerV2']
|
||||
__all__ = ['convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2']
|
||||
|
|
|
@ -184,11 +184,12 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
if param.grad is not None:
|
||||
param.grad = param.grad.to(target_device)
|
||||
|
||||
param.colo_attr = ShardedParamV2(param, set_data_none=False)
|
||||
param.colo_attr = ShardedParamV2(param, set_data_none=True)
|
||||
|
||||
if self.shard_param:
|
||||
self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
|
||||
param.data = param.colo_attr.data_payload # set param.data to payload
|
||||
|
||||
param.data = param.colo_attr.data_payload # set param.data to payload
|
||||
|
||||
# mark whether the param is replicated
|
||||
param.colo_attr.is_replicated = self.is_replicated
|
||||
|
|
|
@ -31,9 +31,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
|
|||
for i in range(world_size):
|
||||
if i == rank:
|
||||
buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
|
||||
# Release payload here, to decrease peak memory usage
|
||||
for t in tensor_list:
|
||||
t.reset_payload(None)
|
||||
else:
|
||||
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
|
||||
dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
|
||||
|
@ -44,6 +41,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
|
|||
for i, t in enumerate(tensor_list):
|
||||
gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list]
|
||||
gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape)
|
||||
t.reset_payload(gathered_payload)
|
||||
t.payload_reset(gathered_payload)
|
||||
t.is_sharded = False
|
||||
offset += tensor_numels[i]
|
||||
|
|
|
@ -3,10 +3,10 @@ from typing import List, Optional
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.shard_utils.commons import get_shard
|
||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline
|
||||
|
||||
|
||||
class TensorShardStrategy(BaseShardStrategy):
|
||||
|
@ -36,7 +36,7 @@ class TensorShardStrategy(BaseShardStrategy):
|
|||
assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
|
||||
f" but current cuda device is {get_current_device()}"
|
||||
sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
|
||||
t.reset_payload(sharded_payload)
|
||||
t.payload_reset(sharded_payload)
|
||||
t.is_sharded = True
|
||||
|
||||
def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
|
||||
|
@ -53,6 +53,6 @@ class TensorShardStrategy(BaseShardStrategy):
|
|||
|
||||
dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
|
||||
gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape)
|
||||
t.reset_payload(gathered_payload)
|
||||
t.payload_reset(gathered_payload)
|
||||
colo_model_data_tensor_move_inline(t, target_device)
|
||||
t.is_sharded = False
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Any, Callable, List, Tuple
|
|||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Union
|
||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||
|
||||
|
||||
def get_gradient_predivide_factor(world_size: int) -> float:
|
||||
|
|
|
@ -17,11 +17,11 @@ from colossalai.gemini.memory_tracer.model_data_memtracer import \
|
|||
GLOBAL_MODEL_DATA_TRACER
|
||||
from colossalai.utils.memory import colo_device_memory_capacity
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_move_to_cpu
|
||||
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
|
||||
from colossalai.zero.sharded_param.tensorful_state import TensorState
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu
|
||||
from colossalai.gemini.stateful_tensor import TensorState
|
||||
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
|
||||
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
|
||||
|
||||
|
@ -358,8 +358,11 @@ class ShardedModelV2(nn.Module):
|
|||
assert param.colo_attr.saved_grad.is_null(
|
||||
), 'Gradien accumulation is not supported when reuse_fp16_shard=True'
|
||||
|
||||
param.colo_attr.reset_grad_payload(grad.data)
|
||||
param.colo_attr.reset_data_payload(grad.data) # release the memory of param
|
||||
param.colo_attr.grad_payload_reset(grad.data)
|
||||
# release the memory of param
|
||||
# we set a false None for parameter's payload
|
||||
# so we can get paramter's device and dtype later in optimizer
|
||||
param.colo_attr.data_payload_reset(torch.empty(0, device=grad.device, dtype=grad.dtype))
|
||||
|
||||
if param.colo_attr.is_replicated:
|
||||
param.colo_attr.sharded_data_tensor.is_sharded = True
|
||||
|
@ -368,7 +371,7 @@ class ShardedModelV2(nn.Module):
|
|||
fp32_grad = cast_tensor_to_fp32(grad)
|
||||
|
||||
if param.colo_attr.saved_grad.is_null():
|
||||
param.colo_attr.reset_grad_payload(fp32_grad)
|
||||
param.colo_attr.grad_payload_reset(fp32_grad)
|
||||
else:
|
||||
param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload))
|
||||
|
||||
|
|
|
@ -12,15 +12,15 @@ from colossalai.logging import get_dist_logger
|
|||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from colossalai.gemini.memory_tracer.model_data_memtracer import \
|
||||
GLOBAL_MODEL_DATA_TRACER
|
||||
from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
|
||||
colo_tensor_mem_usage)
|
||||
from colossalai.gemini.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
|
||||
colo_tensor_mem_usage)
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
|
||||
from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
from torch.optim import Optimizer
|
||||
from colossalai.gemini.stateful_tensor import (StatefulTensor, TensorState)
|
||||
from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
|
||||
|
||||
|
||||
|
@ -253,7 +253,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
for p in group['params']:
|
||||
# p.colo_attr.sharded_data_tensor stores grad now
|
||||
# we have to recover fp16 param
|
||||
reuse_fp16_shard = p.colo_attr.saved_grad.data_ptr() == p.colo_attr.sharded_data_tensor.data_ptr()
|
||||
reuse_fp16_shard = (p.colo_attr.sharded_data_tensor.payload_size == 0)
|
||||
if recover_data and reuse_fp16_shard:
|
||||
self._copy_master_param_to_param_fp16(p)
|
||||
else:
|
||||
|
@ -332,12 +332,23 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
|
||||
def _copy_master_param_to_param_fp16(self, p):
|
||||
# flush gradient
|
||||
p.colo_attr.saved_grad.set_null()
|
||||
if p.colo_attr.sharded_data_tensor.payload_size == 0:
|
||||
# here reuse_fp16_shard is True
|
||||
# in order to use copy below, we should give sharded data tensor a payload
|
||||
p.colo_attr.sharded_data_tensor.payload_relay(p.colo_attr.saved_grad)
|
||||
else:
|
||||
p.colo_attr.saved_grad.set_null()
|
||||
|
||||
p.data = self.master_params[p].payload
|
||||
|
||||
# we need to allocate new memory for keep_not_shard paramters
|
||||
# in order to use copy, otherwise, the sizes of tensor is not compatible
|
||||
if p.colo_attr.data_payload.numel() != p.data.numel():
|
||||
p.colo_attr.data_payload_reset(
|
||||
torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device))
|
||||
|
||||
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
|
||||
p.data = self.master_params[p].payload
|
||||
p.colo_attr.reset_data_payload(
|
||||
colo_model_tensor_clone(p.half().detach(), p.colo_attr.sharded_data_tensor.device))
|
||||
p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach())
|
||||
p.colo_attr.set_data_none()
|
||||
|
||||
if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
|
||||
|
|
|
@ -1,11 +1,5 @@
|
|||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move, colo_model_data_tensor_move_inline,
|
||||
colo_model_data_move_to_cpu, colo_model_tensor_clone,
|
||||
colo_tensor_mem_usage)
|
||||
from colossalai.zero.sharded_param.tensorful_state import TensorState, StatefulTensor
|
||||
|
||||
__all__ = [
|
||||
'ShardedTensor', 'ShardedParamV2', 'colo_model_data_tensor_move', 'colo_model_data_tensor_move_inline',
|
||||
'colo_model_data_move_to_cpu', 'colo_model_tensor_clone', 'colo_tensor_mem_usage', 'TensorState', 'StatefulTensor'
|
||||
]
|
||||
'ShardedTensor', 'ShardedParamV2']
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
import torch
|
||||
from colossalai.zero.sharded_param import ShardedTensor
|
||||
from typing import Optional, Tuple
|
||||
from colossalai.zero.sharded_param.tensor_utils import colo_tensor_mem_usage
|
||||
from .tensorful_state import StatefulTensor, TensorState
|
||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||
from colossalai.gemini.tensor_utils import colo_tensor_mem_usage
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
|
||||
from typing import List
|
||||
|
||||
EMPTY_TENSOR_DICT = {}
|
||||
|
@ -50,6 +50,7 @@ class ShardedParamV2(object):
|
|||
|
||||
@property
|
||||
def data_payload(self):
|
||||
assert not self.sharded_data_tensor.is_null()
|
||||
return self.sharded_data_tensor.payload
|
||||
|
||||
@property
|
||||
|
@ -61,15 +62,15 @@ class ShardedParamV2(object):
|
|||
def param_is_sharded(self):
|
||||
return self.sharded_data_tensor.is_sharded
|
||||
|
||||
def reset_data_payload(self, tensor: torch.Tensor):
|
||||
def data_payload_reset(self, tensor: torch.Tensor):
|
||||
assert type(tensor) is torch.Tensor
|
||||
assert tensor.requires_grad is False
|
||||
self.sharded_data_tensor.reset_payload(tensor)
|
||||
self.sharded_data_tensor.payload_reset(tensor)
|
||||
|
||||
def reset_grad_payload(self, tensor: torch.Tensor):
|
||||
def grad_payload_reset(self, tensor: torch.Tensor):
|
||||
assert type(tensor) is torch.Tensor
|
||||
assert tensor.requires_grad is False
|
||||
self.saved_grad.reset_payload(tensor)
|
||||
self.saved_grad.payload_reset(tensor)
|
||||
|
||||
def get_memory_usage(self) -> Tuple[int, int]:
|
||||
"""
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
import torch
|
||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
|
||||
from typing import Optional
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
|
||||
|
||||
|
||||
class ShardedTensor(StatefulTensor):
|
||||
|
|
|
@ -1,80 +0,0 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
import torch
|
||||
|
||||
|
||||
class TensorState(Enum):
|
||||
FREE = 0
|
||||
HOLD = 1
|
||||
HOLD_AFTER_FWD = 2
|
||||
HOLD_AFTER_BWD = 3
|
||||
COMPUTE = 4
|
||||
|
||||
|
||||
class StatefulTensor(object):
|
||||
"""A Structure stores a Torch Tensor and labeled states.
|
||||
Inspired from the paper:
|
||||
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
|
||||
|
||||
https://arxiv.org/abs/2108.05818
|
||||
"""
|
||||
|
||||
def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
|
||||
self._state = state
|
||||
self._payload = tensor
|
||||
if self._state == TensorState.FREE:
|
||||
assert self._payload is None, f"payload has to None if state is {self._state}"
|
||||
|
||||
def data_ptr(self):
|
||||
if self._payload is None:
|
||||
return None
|
||||
return self._payload.data_ptr()
|
||||
|
||||
@property
|
||||
def state(self) -> TensorState:
|
||||
return self._state
|
||||
|
||||
def set_null(self) -> None:
|
||||
self._state = TensorState.FREE
|
||||
self._payload = None
|
||||
|
||||
def is_null(self) -> bool:
|
||||
if self._state == TensorState.FREE:
|
||||
assert self._payload is None
|
||||
return True
|
||||
return False
|
||||
|
||||
def trans_state(self, state: TensorState) -> None:
|
||||
self._state = state
|
||||
if state == TensorState.FREE:
|
||||
self._payload = None
|
||||
|
||||
@property
|
||||
def payload(self) -> Optional[torch.Tensor]:
|
||||
return self._payload
|
||||
|
||||
def copy_payload(self, tensor) -> None:
|
||||
self._payload.view(-1).copy_(tensor.view(-1))
|
||||
|
||||
def reset_payload(self, tensor) -> None:
|
||||
del self._payload
|
||||
self._payload = tensor
|
||||
self.trans_state(TensorState.HOLD)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self._payload.device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
return self._payload.dtype
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._payload.shape
|
||||
|
||||
def to(self, device: torch.device):
|
||||
raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
|
||||
|
||||
def to_(self, device: torch.device):
|
||||
raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")
|
|
@ -8,12 +8,11 @@ from colossalai.registry import OPHOOKS
|
|||
from colossalai.utils import get_current_device
|
||||
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.sharded_param.tensorful_state import TensorState
|
||||
from colossalai.engine.ophooks import BaseOpHook
|
||||
|
||||
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
|
||||
from colossalai.gemini.memory_tracer import MemStatsCollector
|
||||
from typing import Any
|
||||
from colossalai.gemini.stateful_tensor import TensorState
|
||||
|
||||
|
||||
@OPHOOKS.register_module
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_gemini_manager():
|
||||
# reset the manager, in case that there exists memory information left
|
||||
manager = StatefulTensor.GST_MGR
|
||||
manager.reset()
|
||||
|
||||
# occupation 8
|
||||
st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda'))
|
||||
# occupation 60
|
||||
st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu'))
|
||||
|
||||
# occupation 28
|
||||
t1 = torch.empty(7, device='cuda')
|
||||
# occupation 12
|
||||
t2 = torch.empty(3, device='cpu')
|
||||
st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD)
|
||||
st4 = StatefulTensor(None, TensorState.FREE)
|
||||
|
||||
assert manager.total_number == 4
|
||||
assert manager.total_mem['cpu'] == 60
|
||||
assert manager.total_mem['cuda'] == 36
|
||||
assert manager.state_mem['cpu'][TensorState.HOLD] == 60
|
||||
assert manager.state_mem['cuda'][TensorState.HOLD] == 8
|
||||
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28
|
||||
|
||||
st4.payload_reset(t2)
|
||||
st3.payload_reset(t2)
|
||||
|
||||
assert manager.total_number == 4
|
||||
assert manager.total_mem['cpu'] == 84
|
||||
assert manager.total_mem['cuda'] == 8
|
||||
assert manager.state_mem['cpu'][TensorState.HOLD] == 72
|
||||
assert manager.state_mem['cuda'][TensorState.HOLD] == 8
|
||||
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12
|
||||
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0
|
||||
|
||||
st1.move_to(torch.device('cpu'))
|
||||
st2.move_to(torch.device('cpu'))
|
||||
st3.move_to(torch.device('cuda', 0))
|
||||
|
||||
assert manager.total_number == 4
|
||||
assert manager.total_mem['cpu'] == 80
|
||||
assert manager.total_mem['cuda'] == 12
|
||||
assert manager.state_mem['cpu'][TensorState.HOLD] == 80
|
||||
assert manager.state_mem['cuda'][TensorState.HOLD] == 0
|
||||
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
|
||||
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12
|
||||
|
||||
st1.trans_state(TensorState.COMPUTE)
|
||||
st2.trans_state(TensorState.COMPUTE)
|
||||
st2.trans_state(TensorState.HOLD_AFTER_BWD)
|
||||
|
||||
assert manager.total_number == 4
|
||||
assert manager.total_mem['cpu'] == 80
|
||||
assert manager.total_mem['cuda'] == 12
|
||||
assert manager.state_mem['cpu'][TensorState.HOLD] == 12
|
||||
assert manager.state_mem['cuda'][TensorState.HOLD] == 0
|
||||
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
|
||||
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12
|
||||
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60
|
||||
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0
|
||||
assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8
|
||||
assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_gemini_manager()
|
|
@ -6,9 +6,8 @@ from colossalai.utils.cuda import get_current_device
|
|||
from colossalai.gemini.memory_tracer import MemStatsCollector
|
||||
from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
|
||||
from colossalai.utils.memory import colo_set_process_memory_fraction
|
||||
from colossalai.gemini import StatefulTensorMgr
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
from colossalai.zero.sharded_param.tensorful_state import TensorState
|
||||
from colossalai.gemini.stateful_tensor import TensorState
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from torch.nn.parameter import Parameter
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.zero.sharded_param import ShardedTensor
|
||||
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
|
||||
import colossalai
|
||||
|
||||
import torch
|
||||
|
|
|
@ -11,7 +11,7 @@ from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardS
|
|||
from colossalai.zero.sharded_param import ShardedTensor
|
||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
from tests.test_zero.common import CONFIG, allclose
|
||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||
|
||||
|
||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
|
|
|
@ -2,9 +2,10 @@ import pytest
|
|||
|
||||
import colossalai
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage, colo_model_data_tensor_move,
|
||||
colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
|
||||
colo_model_tensor_clone)
|
||||
from colossalai.gemini.tensor_utils import (colo_tensor_mem_usage, colo_model_data_tensor_move,
|
||||
colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
|
||||
colo_model_tensor_clone)
|
||||
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
|
||||
|
|
Loading…
Reference in New Issue