mirror of https://github.com/hpcaitech/ColossalAI
[gemini] add GeminiMemoryManager (#832)
* refactor StatefulTensor, tensor utilities * add unit test for GeminiMemoryManager (pull/837/head)
parent
35ea6e1023
commit
e5ea3fdeef
|
@ -0,0 +1,45 @@
|
||||||
|
from enum import EnumMeta
|
||||||
|
|
||||||
|
|
||||||
|
class GeminiMemoryManager(object):
    """Bookkeeping of stateful tensor instances and their memory footprint.

    Tracks three things for the tensors that register themselves here:
      * how many instances exist (``total_number``),
      * aggregated byte counts per device type (``total_mem['cpu'|'cuda']``),
      * per-device byte counts broken down by tensor state
        (``state_mem[device][state]``), where the states are the members of
        ``states_cls``.
    """

    def __init__(self, states_cls: EnumMeta):
        super().__init__()
        # Enum class whose members are the tracked tensor states.
        self.states_cls = states_cls
        self._cnter = 0  # number of registered instances

        # Aggregated memory occupation per device type.
        self.total_mem = dict()
        # Per-device, per-state memory occupation.
        self.state_mem = {'cpu': dict(), 'cuda': dict()}

        self.reset()

    @property
    def total_number(self):
        """Number of instances registered so far."""
        return self._cnter

    def reset(self):
        """Zero every counter: instance count, per-device totals and the per-state breakdown."""
        self._cnter = 0  # instance counter

        # Aggregated occupation per device type.
        self.total_mem['cpu'] = 0
        self.total_mem['cuda'] = 0

        # Per-state occupation on each device.
        for device in ('cpu', 'cuda'):
            for state in self.states_cls:
                self.state_mem[device][state] = 0

    def register_new_instance(self):
        """Record that one more stateful tensor instance exists."""
        self._cnter += 1

    def print_info(self):
        """Print a human-readable summary of all tracked counters to stdout."""
        header = [
            f"Total number: {self.total_number}",
            f"Total CPU memory occupation: {self.total_mem['cpu']}",
            f"Total CUDA memory occupation: {self.total_mem['cuda']}\n",
        ]
        print(*header, sep='\n')

        for state in self.states_cls:
            per_state = [
                f"{state}: CPU memory occupation: {self.state_mem['cpu'][state]}",
                f"{state}: CUDA memory occupation: {self.state_mem['cuda'][state]}\n",
            ]
            print(*per_state, sep='\n')
|
|
@ -0,0 +1,204 @@
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Optional
|
||||||
|
import torch
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from colossalai.gemini.gemini_context import GeminiMemoryManager
|
||||||
|
|
||||||
|
|
||||||
|
def sizeof_tensor(tensor: torch.Tensor):
    """Return the number of bytes occupied by *tensor*'s elements."""
    return tensor.element_size() * tensor.numel()
|
||||||
|
|
||||||
|
|
||||||
|
class TensorState(Enum):
    """Lifecycle states of a StatefulTensor's payload.

    FREE means no payload is attached (StatefulTensor asserts payload is None
    in that state). The HOLD_* members presumably distinguish when in the
    training step the payload is merely held — TODO confirm against the
    training loop that sets them.
    """
    FREE = 0             # no payload attached
    HOLD = 1             # payload attached and held
    HOLD_AFTER_FWD = 2   # held after the forward pass (assumed from name)
    HOLD_AFTER_BWD = 3   # held after the backward pass (assumed from name)
    COMPUTE = 4          # payload actively used by computation
|
||||||
|
|
||||||
|
|
||||||
|
class StatefulTensor(object):
    """A Structure stores a Torch Tensor and labeled states.

    Every instance reports its memory occupation to the process-wide
    ``GST_MGR`` so the Gemini runtime can account CPU/CUDA usage per state.
    Inspired from the paper:
    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management

    https://arxiv.org/abs/2108.05818
    """
    # Global Stateful Tensor Manager: shared accounting for all instances.
    GST_MGR = GeminiMemoryManager(TensorState)

    def __init__(self, maybe_tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
        """Wrap ``maybe_tensor`` with an initial ``state``.

        ``maybe_tensor`` must be None iff ``state`` is FREE; otherwise the
        tensor becomes the payload and its size is reported to the manager.
        """
        self._state = state
        self._payload = None
        self._payload_size = 0    # byte size of current payload

        StatefulTensor.GST_MGR.register_new_instance()

        if self._state == TensorState.FREE:
            # when the state is free, payload should be None
            assert maybe_tensor is None, f"payload has to None if state is {self._state}"
        else:
            # otherwise, payload should not be None
            assert maybe_tensor is not None, f"payload can't be None if state is {self._state}"
            self._payload = maybe_tensor
            self._payload_size = sizeof_tensor(maybe_tensor)
            # FREE -> state: adds this payload's memory to the manager
            self.__trans_state_update(TensorState.FREE, state)

    def data_ptr(self):
        """Return the payload's data pointer, or 0 when there is no payload."""
        if self._payload is None:
            return 0    # if a tensor has no storage, 0 should be returned
        return self._payload.data_ptr()

    def set_null(self) -> None:
        """Drop the payload and move to FREE, updating the manager."""
        # notice that free stateful tensor do not need to become null again
        if self.state != TensorState.FREE:
            self.__trans_state_update(self.state, TensorState.FREE)
            self.__release()

    def is_null(self) -> bool:
        """True iff the tensor is FREE (and therefore has no payload)."""
        if self.state == TensorState.FREE:
            # check sanity here
            assert self.payload is None
            return True
        return False

    def trans_state(self, state: TensorState) -> None:
        """Transition to ``state``, keeping the global manager in sync.

        A FREE tensor may only "transition" to FREE (a no-op); transitioning
        a non-FREE tensor to FREE releases the payload.
        """
        if self.state == TensorState.FREE:
            # free stateful tensor can't change state
            assert state == TensorState.FREE, "Free stateful tensor can't change to other states"
            return

        self.__trans_state_update(self.state, state)

        if state == TensorState.FREE:
            self.__release()
        else:
            self._state = state

    def move_to(self, device: Union[torch.device, int]):
        """Move the payload to ``device`` (an int is taken as a CUDA index).

        NOTE(review): only the device *type* is compared, so moving between
        two CUDA devices (cuda:0 -> cuda:1) is a no-op here — confirm this
        is intentional (accounting is only per device type).
        """
        assert self.state is not TensorState.FREE, "Can't move free stateful tensor"

        if not isinstance(device, torch.device):
            to_device = torch.device('cuda', device)
        else:
            to_device = device

        from_device_type = self.device.type
        if from_device_type == to_device.type:
            # from device == to device
            return

        # update manager's information before physically moving the data
        self.__trans_device_update(from_device_type, to_device.type)
        self.payload.data = self.payload.data.to(to_device)

    def payload_copy(self, tensor) -> None:
        """Copy ``tensor``'s elements into the existing payload in place.

        Assumes ``tensor`` has the same number of elements as the payload —
        TODO confirm callers guarantee this.
        """
        self._payload.view(-1).copy_(tensor.view(-1))

    def payload_reset(self, tensor) -> None:
        """Replace the payload with ``tensor`` and re-report memory usage."""

        assert tensor is not None, "Can't reset None for stateful tensors, please use set_null() instead"

        if self.payload is not None:
            # release old payload
            self.__trans_state_update(self.state, TensorState.FREE)
        else:
            # otherwise, set the state to HOLD for new payload
            self._state = TensorState.HOLD
        del self._payload

        self._payload = tensor
        self._payload_size = sizeof_tensor(tensor)
        # record new payload
        self.__trans_state_update(TensorState.FREE, self.state)

    def payload_relay(self, rhs):
        """Take over ``rhs``'s payload; ``rhs`` is released afterwards."""
        # relay the payload of rhs to current stateful tensor
        # can't support null relay right now
        assert not rhs.is_null()

        # now this function only support stateful tensor that has zero-length payload
        # because it doesn't require memory manager updating
        # you can extend this function by yourself
        assert self.payload_size == 0

        self._payload = rhs.payload
        self._payload_size = rhs.payload_size
        self._state = TensorState.HOLD
        # re-label rhs's memory as this tensor's HOLD memory
        self.__trans_state_update(rhs.state, TensorState.HOLD)

        rhs.__release()

    @property
    def payload(self) -> Optional[torch.Tensor]:
        # current payload tensor; None when state is FREE
        return self._payload

    @property
    def payload_size(self) -> int:
        # byte size of the current payload (0 when FREE)
        return self._payload_size

    @property
    def state(self) -> TensorState:
        return self._state

    @property
    def device(self) -> torch.device:
        # raises AttributeError if the payload is None (FREE tensor)
        return self._payload.device

    @property
    def dtype(self) -> torch.dtype:
        return self._payload.dtype

    @property
    def shape(self):
        return self._payload.shape

    def to(self, device: torch.device):
        # forbid torch-style .to(); moves must go through move_to() so the
        # global manager stays consistent
        raise RuntimeError("Use move_to(...) instead of call .to() on StatefulTensor")

    def to_(self, device: torch.device):
        raise RuntimeError("Use move_to(...) instead of call .to_() on StatefulTensor")

    def __release(self):
        # release current payload
        # shouldn't be visible to users
        self._state = TensorState.FREE
        self._payload = None
        self._payload_size = 0

    def __trans_state_update(self, from_state: TensorState, to_state: TensorState):
        """Update global manager when changing the state of a tensor
        """
        manager = StatefulTensor.GST_MGR
        size = self.payload_size
        # NOTE: reads self.device, so the payload must already be attached
        device_type = self.device.type

        if from_state != TensorState.FREE:
            manager.state_mem[device_type][from_state] -= size
        else:
            # when from_state is FREE, the tensor is new to manager
            # we should add its memory
            manager.total_mem[device_type] += size

        if to_state != TensorState.FREE:
            manager.state_mem[device_type][to_state] += size
        else:
            # when to_state is FREE, the tensor will be deleted soon
            # we should sub its memory
            manager.total_mem[device_type] -= size

    def __trans_device_update(self, from_type: str, to_type: str):
        """Update global manager when changing the device of a tensor
        """
        manager = StatefulTensor.GST_MGR
        size = self.payload_size
        state = self.state

        # update aggregated information
        manager.total_mem[from_type] -= size
        manager.total_mem[to_type] += size

        # update the information of each state
        manager.state_mem[from_type][state] -= size
        manager.state_mem[to_type][state] += size
|
|
@ -2,9 +2,8 @@ import functools
|
||||||
import torch
|
import torch
|
||||||
import types
|
import types
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.cuda import get_current_device
|
||||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
|
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
|
||||||
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
|
||||||
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
|
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicy
|
||||||
from typing import List
|
from typing import List
|
||||||
from colossalai.logging import get_dist_logger
|
from colossalai.logging import get_dist_logger
|
||||||
|
@ -30,7 +29,8 @@ class StatefulTensorMgr(object):
|
||||||
|
|
||||||
self._cpu_gpu_move_volume = 0
|
self._cpu_gpu_move_volume = 0
|
||||||
|
|
||||||
def register_stateful_param(self, param: ShardedParamV2) -> None:
|
def register_stateful_param(self, param) -> None:
|
||||||
|
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||||
assert isinstance(param, ShardedParamV2)
|
assert isinstance(param, ShardedParamV2)
|
||||||
for t in param.get_payload_tensors():
|
for t in param.get_payload_tensors():
|
||||||
assert isinstance(t, StatefulTensor)
|
assert isinstance(t, StatefulTensor)
|
||||||
|
|
|
@ -4,8 +4,8 @@ import torch
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from colossalai.utils.memory import colo_device_memory_capacity
|
from colossalai.utils.memory import colo_device_memory_capacity
|
||||||
|
|
||||||
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline, colo_tensor_mem_usage
|
||||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
|
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||||
from colossalai.gemini.memory_tracer import MemStatsCollector
|
from colossalai.gemini.memory_tracer import MemStatsCollector
|
||||||
from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
|
from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
|
||||||
from typing import Type
|
from typing import Type
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
import torch
|
import torch
|
||||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
|
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||||
from typing import Union, Tuple
|
from typing import Union, Tuple
|
||||||
|
|
||||||
|
|
||||||
def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
|
def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
|
||||||
if issubclass(type(tensor), StatefulTensor):
|
if isinstance(tensor, StatefulTensor):
|
||||||
t = tensor.payload
|
t = tensor.payload
|
||||||
elif isinstance(tensor, torch.Tensor):
|
elif isinstance(tensor, torch.Tensor):
|
||||||
t = tensor
|
t = tensor
|
||||||
|
@ -32,15 +32,16 @@ def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_
|
||||||
|
|
||||||
The function will record the communication volume between CPU and GPU.
|
The function will record the communication volume between CPU and GPU.
|
||||||
Args:
|
Args:
|
||||||
t_src (Union[StatefulTensor, torch.Tensor]): source tensor
|
src_t (Union[StatefulTensor, torch.Tensor]): source tensor
|
||||||
tgt_t (Union[StatefulTensor, torch.Tensor]): target tensor
|
tgt_t (Union[StatefulTensor, torch.Tensor]): target tensor
|
||||||
"""
|
"""
|
||||||
if issubclass(type(src_t), StatefulTensor):
|
if isinstance(src_t, StatefulTensor):
|
||||||
src_t_payload = src_t.payload
|
src_t_payload = src_t.payload
|
||||||
else:
|
else:
|
||||||
src_t_payload = src_t.data
|
src_t_payload = src_t.data
|
||||||
src_dev = src_t_payload.device
|
src_dev = src_t_payload.device
|
||||||
if issubclass(type(tgt_t), StatefulTensor):
|
|
||||||
|
if isinstance(tgt_t, StatefulTensor):
|
||||||
tgt_t_payload = tgt_t.payload
|
tgt_t_payload = tgt_t.payload
|
||||||
else:
|
else:
|
||||||
tgt_t_payload = tgt_t.data
|
tgt_t_payload = tgt_t.data
|
||||||
|
@ -48,10 +49,10 @@ def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_
|
||||||
tgt_t_payload.copy_(src_t_payload)
|
tgt_t_payload.copy_(src_t_payload)
|
||||||
|
|
||||||
# remove payload of src_t
|
# remove payload of src_t
|
||||||
if issubclass(type(src_t), StatefulTensor):
|
if isinstance(src_t, StatefulTensor):
|
||||||
src_t.reset_payload(torch.tensor([], device=src_dev, dtype=src_t_payload.dtype))
|
src_t.set_null()
|
||||||
else:
|
else:
|
||||||
src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)
|
src_t.data = torch.empty(0, device=src_dev, dtype=src_t_payload.dtype)
|
||||||
|
|
||||||
|
|
||||||
def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device,
|
def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], target_device: Union[torch.device,
|
||||||
|
@ -62,56 +63,42 @@ def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], t
|
||||||
t (Union[StatefulTensor, torch.Tensor]): the tensor be moved
|
t (Union[StatefulTensor, torch.Tensor]): the tensor be moved
|
||||||
target_device: a traget device, if type is int, it the index of cuda card.
|
target_device: a traget device, if type is int, it the index of cuda card.
|
||||||
"""
|
"""
|
||||||
if isinstance(t, torch.Tensor):
|
|
||||||
t_payload = t
|
|
||||||
elif issubclass(type(t), StatefulTensor):
|
|
||||||
t_payload = t.payload
|
|
||||||
else:
|
|
||||||
raise TypeError('colo_model_data_move_to_cpu dose not accept type {type(t)}')
|
|
||||||
|
|
||||||
if not isinstance(target_device, torch.device):
|
if not isinstance(target_device, torch.device):
|
||||||
target_device = torch.device(f'cuda:{target_device}')
|
target_device = torch.device(f'cuda:{target_device}')
|
||||||
|
|
||||||
# deal with torch.device('cpu') and torch.device('cpu:0)
|
if isinstance(t, torch.Tensor):
|
||||||
if t_payload.device.type == target_device.type:
|
t.data = t.data.to(target_device)
|
||||||
return
|
elif isinstance(t, StatefulTensor):
|
||||||
t_payload.data = t_payload.data.to(target_device)
|
t.move_to(target_device)
|
||||||
|
else:
|
||||||
|
raise TypeError(f'colo_model_data_tensor_move_inline dose not accept type {type(t)}')
|
||||||
|
|
||||||
|
|
||||||
def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None:
|
def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None:
|
||||||
"""colo_model_data_move_to_cpu
|
"""colo_model_data_move_to_cpu
|
||||||
|
|
||||||
move a model data tensor from gpu to cpu
|
move a model data tensor from gpu to cpu
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
t (Union[StatefulTensor, torch.Tensor]): _description_
|
t (Union[StatefulTensor, torch.Tensor]): _description_
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if issubclass(type(t), StatefulTensor):
|
|
||||||
t_payload = t.payload
|
|
||||||
elif isinstance(t, torch.Tensor):
|
|
||||||
t_payload = t
|
|
||||||
else:
|
|
||||||
raise TypeError('colo_model_data_move_to_cpu dose not accept type {type(t)}')
|
|
||||||
|
|
||||||
if t_payload.device.type == 'cpu':
|
|
||||||
return
|
|
||||||
|
|
||||||
# TODO() optimize the tensor moving with non-blocking
|
# TODO() optimize the tensor moving with non-blocking
|
||||||
t_payload.data = t_payload.data.cpu()
|
if isinstance(t, torch.Tensor):
|
||||||
|
t.data = t.data.cpu()
|
||||||
|
elif isinstance(t, StatefulTensor):
|
||||||
|
t.move_to(torch.device('cpu'))
|
||||||
|
else:
|
||||||
|
raise TypeError(f'colo_model_data_move_to_cpu dose not accept type {type(t)}')
|
||||||
|
|
||||||
|
|
||||||
def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
|
def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Clone a model data tensor
|
Clone a model data tensor
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
t (Union[StatefulTensor, torch.Tensor]): a model data tensor
|
t (Union[StatefulTensor, torch.Tensor]): a model data tensor
|
||||||
target_device (torch.device): the target device
|
target_device (torch.device): the target device
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: a cloned torch tensor
|
torch.Tensor: a cloned torch tensor
|
||||||
"""
|
"""
|
||||||
t_payload = t.payload if issubclass(type(t), StatefulTensor) else t
|
# TODO() rename this function
|
||||||
|
colo_model_data_tensor_move_inline(t, target_device)
|
||||||
ret = t_payload.to(target_device)
|
t_payload = t.payload if isinstance(t, StatefulTensor) else t
|
||||||
return ret
|
return t_payload
|
|
@ -8,7 +8,7 @@ from .experts import FFNExperts, TPExperts
|
||||||
class ForceFP32Parameter(torch.nn.Parameter):
|
class ForceFP32Parameter(torch.nn.Parameter):
|
||||||
|
|
||||||
def half(self, memory_format=None):
|
def half(self, memory_format=None):
|
||||||
return self.data
|
return self.data.clone()
|
||||||
|
|
||||||
|
|
||||||
class NormalNoiseGenerator:
|
class NormalNoiseGenerator:
|
||||||
|
|
|
@ -35,4 +35,4 @@ def convert_to_zero_v2(model: nn.Module, optimizer: torch.optim.Optimizer, model
|
||||||
return zero_model, zero_optimizer
|
return zero_model, zero_optimizer
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['convert_to_zerov2', 'ShardedModelV2', 'ShardedOptimizerV2']
|
__all__ = ['convert_to_zero_v2', 'ShardedModelV2', 'ShardedOptimizerV2']
|
||||||
|
|
|
@ -184,10 +184,11 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
if param.grad is not None:
|
if param.grad is not None:
|
||||||
param.grad = param.grad.to(target_device)
|
param.grad = param.grad.to(target_device)
|
||||||
|
|
||||||
param.colo_attr = ShardedParamV2(param, set_data_none=False)
|
param.colo_attr = ShardedParamV2(param, set_data_none=True)
|
||||||
|
|
||||||
if self.shard_param:
|
if self.shard_param:
|
||||||
self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
|
self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)
|
||||||
|
|
||||||
param.data = param.colo_attr.data_payload # set param.data to payload
|
param.data = param.colo_attr.data_payload # set param.data to payload
|
||||||
|
|
||||||
# mark whether the param is replicated
|
# mark whether the param is replicated
|
||||||
|
|
|
@ -31,9 +31,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
|
||||||
for i in range(world_size):
|
for i in range(world_size):
|
||||||
if i == rank:
|
if i == rank:
|
||||||
buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
|
buffer_list.append(flatten([t.payload for t in tensor_list]).cuda(get_current_device()))
|
||||||
# Release payload here, to decrease peak memory usage
|
|
||||||
for t in tensor_list:
|
|
||||||
t.reset_payload(None)
|
|
||||||
else:
|
else:
|
||||||
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
|
buffer_list.append(torch.zeros(buffer_size, dtype=dtype, device=get_current_device()))
|
||||||
dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
|
dist.all_gather(buffer_list, buffer_list[rank], group=process_group)
|
||||||
|
@ -44,6 +41,6 @@ class BucketTensorShardStrategy(TensorShardStrategy):
|
||||||
for i, t in enumerate(tensor_list):
|
for i, t in enumerate(tensor_list):
|
||||||
gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list]
|
gathered_payload = [buffer[offset:offset + tensor_numels[i]] for buffer in buffer_list]
|
||||||
gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape)
|
gathered_payload = torch.cat(gathered_payload)[:t.origin_numel].view(t.origin_shape)
|
||||||
t.reset_payload(gathered_payload)
|
t.payload_reset(gathered_payload)
|
||||||
t.is_sharded = False
|
t.is_sharded = False
|
||||||
offset += tensor_numels[i]
|
offset += tensor_numels[i]
|
||||||
|
|
|
@ -3,10 +3,10 @@ from typing import List, Optional
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move_inline
|
|
||||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||||
from colossalai.zero.shard_utils.commons import get_shard
|
from colossalai.zero.shard_utils.commons import get_shard
|
||||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||||
|
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move_inline
|
||||||
|
|
||||||
|
|
||||||
class TensorShardStrategy(BaseShardStrategy):
|
class TensorShardStrategy(BaseShardStrategy):
|
||||||
|
@ -36,7 +36,7 @@ class TensorShardStrategy(BaseShardStrategy):
|
||||||
assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
|
assert t.payload.device == get_current_device(), f"shard tensor on cuda device index {t.payload.device.index},"\
|
||||||
f" but current cuda device is {get_current_device()}"
|
f" but current cuda device is {get_current_device()}"
|
||||||
sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
|
sharded_payload, _ = get_shard(t.payload, dist.get_rank(process_group), dist.get_world_size(process_group))
|
||||||
t.reset_payload(sharded_payload)
|
t.payload_reset(sharded_payload)
|
||||||
t.is_sharded = True
|
t.is_sharded = True
|
||||||
|
|
||||||
def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
|
def _gather_tensor(self, t: ShardedTensor, process_group: Optional[dist.ProcessGroup] = None):
|
||||||
|
@ -53,6 +53,6 @@ class TensorShardStrategy(BaseShardStrategy):
|
||||||
|
|
||||||
dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
|
dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False)
|
||||||
gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape)
|
gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape)
|
||||||
t.reset_payload(gathered_payload)
|
t.payload_reset(gathered_payload)
|
||||||
colo_model_data_tensor_move_inline(t, target_device)
|
colo_model_data_tensor_move_inline(t, target_device)
|
||||||
t.is_sharded = False
|
t.is_sharded = False
|
||||||
|
|
|
@ -3,7 +3,7 @@ from typing import Any, Callable, List, Tuple
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
|
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||||
|
|
||||||
|
|
||||||
def get_gradient_predivide_factor(world_size: int) -> float:
|
def get_gradient_predivide_factor(world_size: int) -> float:
|
||||||
|
|
|
@ -17,11 +17,11 @@ from colossalai.gemini.memory_tracer.model_data_memtracer import \
|
||||||
GLOBAL_MODEL_DATA_TRACER
|
GLOBAL_MODEL_DATA_TRACER
|
||||||
from colossalai.utils.memory import colo_device_memory_capacity
|
from colossalai.utils.memory import colo_device_memory_capacity
|
||||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||||
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_move_to_cpu
|
|
||||||
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
|
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
|
||||||
from colossalai.zero.sharded_param.tensorful_state import TensorState
|
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
from colossalai.gemini.tensor_utils import colo_model_data_move_to_cpu
|
||||||
|
from colossalai.gemini.stateful_tensor import TensorState
|
||||||
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
|
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
|
||||||
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
|
from colossalai.gemini.tensor_placement_policy import TensorPlacementPolicyFactory, TensorPlacementPolicy
|
||||||
|
|
||||||
|
@ -358,8 +358,11 @@ class ShardedModelV2(nn.Module):
|
||||||
assert param.colo_attr.saved_grad.is_null(
|
assert param.colo_attr.saved_grad.is_null(
|
||||||
), 'Gradien accumulation is not supported when reuse_fp16_shard=True'
|
), 'Gradien accumulation is not supported when reuse_fp16_shard=True'
|
||||||
|
|
||||||
param.colo_attr.reset_grad_payload(grad.data)
|
param.colo_attr.grad_payload_reset(grad.data)
|
||||||
param.colo_attr.reset_data_payload(grad.data) # release the memory of param
|
# release the memory of param
|
||||||
|
# we set a false None for parameter's payload
|
||||||
|
# so we can get paramter's device and dtype later in optimizer
|
||||||
|
param.colo_attr.data_payload_reset(torch.empty(0, device=grad.device, dtype=grad.dtype))
|
||||||
|
|
||||||
if param.colo_attr.is_replicated:
|
if param.colo_attr.is_replicated:
|
||||||
param.colo_attr.sharded_data_tensor.is_sharded = True
|
param.colo_attr.sharded_data_tensor.is_sharded = True
|
||||||
|
@ -368,7 +371,7 @@ class ShardedModelV2(nn.Module):
|
||||||
fp32_grad = cast_tensor_to_fp32(grad)
|
fp32_grad = cast_tensor_to_fp32(grad)
|
||||||
|
|
||||||
if param.colo_attr.saved_grad.is_null():
|
if param.colo_attr.saved_grad.is_null():
|
||||||
param.colo_attr.reset_grad_payload(fp32_grad)
|
param.colo_attr.grad_payload_reset(fp32_grad)
|
||||||
else:
|
else:
|
||||||
param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload))
|
param.colo_attr.grad_payload.add_(fp32_grad.view_as(param.colo_attr.grad_payload))
|
||||||
|
|
||||||
|
|
|
@ -12,15 +12,15 @@ from colossalai.logging import get_dist_logger
|
||||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||||
from colossalai.gemini.memory_tracer.model_data_memtracer import \
|
from colossalai.gemini.memory_tracer.model_data_memtracer import \
|
||||||
GLOBAL_MODEL_DATA_TRACER
|
GLOBAL_MODEL_DATA_TRACER
|
||||||
from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
|
from colossalai.gemini.tensor_utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
|
||||||
colo_tensor_mem_usage)
|
colo_tensor_mem_usage)
|
||||||
from colossalai.zero.sharded_model import ShardedModelV2
|
from colossalai.zero.sharded_model import ShardedModelV2
|
||||||
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
|
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
|
||||||
from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.distributed import ProcessGroup
|
from torch.distributed import ProcessGroup
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
from colossalai.gemini.stateful_tensor import (StatefulTensor, TensorState)
|
||||||
from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
|
from colossalai.gemini.tensor_placement_policy import AutoTensorPlacementPolicy
|
||||||
|
|
||||||
|
|
||||||
|
@ -253,7 +253,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
||||||
for p in group['params']:
|
for p in group['params']:
|
||||||
# p.colo_attr.sharded_data_tensor stores grad now
|
# p.colo_attr.sharded_data_tensor stores grad now
|
||||||
# we have to recover fp16 param
|
# we have to recover fp16 param
|
||||||
reuse_fp16_shard = p.colo_attr.saved_grad.data_ptr() == p.colo_attr.sharded_data_tensor.data_ptr()
|
reuse_fp16_shard = (p.colo_attr.sharded_data_tensor.payload_size == 0)
|
||||||
if recover_data and reuse_fp16_shard:
|
if recover_data and reuse_fp16_shard:
|
||||||
self._copy_master_param_to_param_fp16(p)
|
self._copy_master_param_to_param_fp16(p)
|
||||||
else:
|
else:
|
||||||
|
@ -332,12 +332,23 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
||||||
|
|
||||||
def _copy_master_param_to_param_fp16(self, p):
|
def _copy_master_param_to_param_fp16(self, p):
|
||||||
# flush gradient
|
# flush gradient
|
||||||
|
if p.colo_attr.sharded_data_tensor.payload_size == 0:
|
||||||
|
# here reuse_fp16_shard is True
|
||||||
|
# in order to use copy below, we should give sharded data tensor a payload
|
||||||
|
p.colo_attr.sharded_data_tensor.payload_relay(p.colo_attr.saved_grad)
|
||||||
|
else:
|
||||||
p.colo_attr.saved_grad.set_null()
|
p.colo_attr.saved_grad.set_null()
|
||||||
|
|
||||||
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
|
|
||||||
p.data = self.master_params[p].payload
|
p.data = self.master_params[p].payload
|
||||||
p.colo_attr.reset_data_payload(
|
|
||||||
colo_model_tensor_clone(p.half().detach(), p.colo_attr.sharded_data_tensor.device))
|
# we need to allocate new memory for keep_not_shard paramters
|
||||||
|
# in order to use copy, otherwise, the sizes of tensor is not compatible
|
||||||
|
if p.colo_attr.data_payload.numel() != p.data.numel():
|
||||||
|
p.colo_attr.data_payload_reset(
|
||||||
|
torch.empty(p.data.shape, dtype=p.colo_attr.data_payload.dtype, device=p.colo_attr.data_payload.device))
|
||||||
|
|
||||||
|
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
|
||||||
|
p.colo_attr.sharded_data_tensor.payload_copy(p.half().detach())
|
||||||
p.colo_attr.set_data_none()
|
p.colo_attr.set_data_none()
|
||||||
|
|
||||||
if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
|
if p.colo_attr.keep_not_shard and p.colo_attr.is_replicated:
|
||||||
|
|
|
@ -1,11 +1,5 @@
|
||||||
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||||
from colossalai.zero.sharded_param.tensor_utils import (colo_model_data_tensor_move, colo_model_data_tensor_move_inline,
|
|
||||||
colo_model_data_move_to_cpu, colo_model_tensor_clone,
|
|
||||||
colo_tensor_mem_usage)
|
|
||||||
from colossalai.zero.sharded_param.tensorful_state import TensorState, StatefulTensor
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'ShardedTensor', 'ShardedParamV2', 'colo_model_data_tensor_move', 'colo_model_data_tensor_move_inline',
|
'ShardedTensor', 'ShardedParamV2']
|
||||||
'colo_model_data_move_to_cpu', 'colo_model_tensor_clone', 'colo_tensor_mem_usage', 'TensorState', 'StatefulTensor'
|
|
||||||
]
|
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
from colossalai.zero.sharded_param import ShardedTensor
|
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
from colossalai.zero.sharded_param.tensor_utils import colo_tensor_mem_usage
|
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||||
from .tensorful_state import StatefulTensor, TensorState
|
from colossalai.gemini.tensor_utils import colo_tensor_mem_usage
|
||||||
|
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
EMPTY_TENSOR_DICT = {}
|
EMPTY_TENSOR_DICT = {}
|
||||||
|
@ -50,6 +50,7 @@ class ShardedParamV2(object):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def data_payload(self):
|
def data_payload(self):
|
||||||
|
assert not self.sharded_data_tensor.is_null()
|
||||||
return self.sharded_data_tensor.payload
|
return self.sharded_data_tensor.payload
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -61,15 +62,15 @@ class ShardedParamV2(object):
|
||||||
def param_is_sharded(self):
|
def param_is_sharded(self):
|
||||||
return self.sharded_data_tensor.is_sharded
|
return self.sharded_data_tensor.is_sharded
|
||||||
|
|
||||||
def reset_data_payload(self, tensor: torch.Tensor):
|
def data_payload_reset(self, tensor: torch.Tensor):
|
||||||
assert type(tensor) is torch.Tensor
|
assert type(tensor) is torch.Tensor
|
||||||
assert tensor.requires_grad is False
|
assert tensor.requires_grad is False
|
||||||
self.sharded_data_tensor.reset_payload(tensor)
|
self.sharded_data_tensor.payload_reset(tensor)
|
||||||
|
|
||||||
def reset_grad_payload(self, tensor: torch.Tensor):
|
def grad_payload_reset(self, tensor: torch.Tensor):
|
||||||
assert type(tensor) is torch.Tensor
|
assert type(tensor) is torch.Tensor
|
||||||
assert tensor.requires_grad is False
|
assert tensor.requires_grad is False
|
||||||
self.saved_grad.reset_payload(tensor)
|
self.saved_grad.payload_reset(tensor)
|
||||||
|
|
||||||
def get_memory_usage(self) -> Tuple[int, int]:
|
def get_memory_usage(self) -> Tuple[int, int]:
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
import torch
|
import torch
|
||||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
|
from colossalai.gemini.stateful_tensor import StatefulTensor, TensorState
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
|
|
||||||
class ShardedTensor(StatefulTensor):
|
class ShardedTensor(StatefulTensor):
|
||||||
|
|
|
@ -1,80 +0,0 @@
|
||||||
from enum import Enum
|
|
||||||
from typing import Optional
|
|
||||||
import torch
|
|
||||||
|
|
||||||
|
|
||||||
class TensorState(Enum):
|
|
||||||
FREE = 0
|
|
||||||
HOLD = 1
|
|
||||||
HOLD_AFTER_FWD = 2
|
|
||||||
HOLD_AFTER_BWD = 3
|
|
||||||
COMPUTE = 4
|
|
||||||
|
|
||||||
|
|
||||||
class StatefulTensor(object):
|
|
||||||
"""A Structure stores a Torch Tensor and labeled states.
|
|
||||||
Inspired from the paper:
|
|
||||||
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
|
|
||||||
|
|
||||||
https://arxiv.org/abs/2108.05818
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, tensor: Optional[torch.Tensor], state: Optional[TensorState] = TensorState.HOLD) -> None:
|
|
||||||
self._state = state
|
|
||||||
self._payload = tensor
|
|
||||||
if self._state == TensorState.FREE:
|
|
||||||
assert self._payload is None, f"payload has to None if state is {self._state}"
|
|
||||||
|
|
||||||
def data_ptr(self):
|
|
||||||
if self._payload is None:
|
|
||||||
return None
|
|
||||||
return self._payload.data_ptr()
|
|
||||||
|
|
||||||
@property
|
|
||||||
def state(self) -> TensorState:
|
|
||||||
return self._state
|
|
||||||
|
|
||||||
def set_null(self) -> None:
|
|
||||||
self._state = TensorState.FREE
|
|
||||||
self._payload = None
|
|
||||||
|
|
||||||
def is_null(self) -> bool:
|
|
||||||
if self._state == TensorState.FREE:
|
|
||||||
assert self._payload is None
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def trans_state(self, state: TensorState) -> None:
|
|
||||||
self._state = state
|
|
||||||
if state == TensorState.FREE:
|
|
||||||
self._payload = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def payload(self) -> Optional[torch.Tensor]:
|
|
||||||
return self._payload
|
|
||||||
|
|
||||||
def copy_payload(self, tensor) -> None:
|
|
||||||
self._payload.view(-1).copy_(tensor.view(-1))
|
|
||||||
|
|
||||||
def reset_payload(self, tensor) -> None:
|
|
||||||
del self._payload
|
|
||||||
self._payload = tensor
|
|
||||||
self.trans_state(TensorState.HOLD)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def device(self) -> torch.device:
|
|
||||||
return self._payload.device
|
|
||||||
|
|
||||||
@property
|
|
||||||
def dtype(self) -> torch.dtype:
|
|
||||||
return self._payload.dtype
|
|
||||||
|
|
||||||
@property
|
|
||||||
def shape(self):
|
|
||||||
return self._payload.shape
|
|
||||||
|
|
||||||
def to(self, device: torch.device):
|
|
||||||
raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
|
|
||||||
|
|
||||||
def to_(self, device: torch.device):
|
|
||||||
raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")
|
|
|
@ -8,12 +8,11 @@ from colossalai.registry import OPHOOKS
|
||||||
from colossalai.utils import get_current_device
|
from colossalai.utils import get_current_device
|
||||||
|
|
||||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||||
from colossalai.zero.sharded_param.tensorful_state import TensorState
|
|
||||||
from colossalai.engine.ophooks import BaseOpHook
|
from colossalai.engine.ophooks import BaseOpHook
|
||||||
|
|
||||||
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
|
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
|
||||||
from colossalai.gemini.memory_tracer import MemStatsCollector
|
from colossalai.gemini.memory_tracer import MemStatsCollector
|
||||||
from typing import Any
|
from colossalai.gemini.stateful_tensor import TensorState
|
||||||
|
|
||||||
|
|
||||||
@OPHOOKS.register_module
|
@OPHOOKS.register_module
|
||||||
|
|
|
@ -0,0 +1,73 @@
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from colossalai.gemini.stateful_tensor import TensorState, StatefulTensor
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
|
||||||
|
def test_gemini_manager():
|
||||||
|
# reset the manager, in case that there exists memory information left
|
||||||
|
manager = StatefulTensor.GST_MGR
|
||||||
|
manager.reset()
|
||||||
|
|
||||||
|
# occupation 8
|
||||||
|
st1 = StatefulTensor(torch.empty(2, 2, dtype=torch.float16, device='cuda'))
|
||||||
|
# occupation 60
|
||||||
|
st2 = StatefulTensor(torch.empty(3, 5, dtype=torch.float32, device='cpu'))
|
||||||
|
|
||||||
|
# occupation 28
|
||||||
|
t1 = torch.empty(7, device='cuda')
|
||||||
|
# occupation 12
|
||||||
|
t2 = torch.empty(3, device='cpu')
|
||||||
|
st3 = StatefulTensor(t1, TensorState.HOLD_AFTER_FWD)
|
||||||
|
st4 = StatefulTensor(None, TensorState.FREE)
|
||||||
|
|
||||||
|
assert manager.total_number == 4
|
||||||
|
assert manager.total_mem['cpu'] == 60
|
||||||
|
assert manager.total_mem['cuda'] == 36
|
||||||
|
assert manager.state_mem['cpu'][TensorState.HOLD] == 60
|
||||||
|
assert manager.state_mem['cuda'][TensorState.HOLD] == 8
|
||||||
|
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 28
|
||||||
|
|
||||||
|
st4.payload_reset(t2)
|
||||||
|
st3.payload_reset(t2)
|
||||||
|
|
||||||
|
assert manager.total_number == 4
|
||||||
|
assert manager.total_mem['cpu'] == 84
|
||||||
|
assert manager.total_mem['cuda'] == 8
|
||||||
|
assert manager.state_mem['cpu'][TensorState.HOLD] == 72
|
||||||
|
assert manager.state_mem['cuda'][TensorState.HOLD] == 8
|
||||||
|
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 12
|
||||||
|
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 0
|
||||||
|
|
||||||
|
st1.move_to(torch.device('cpu'))
|
||||||
|
st2.move_to(torch.device('cpu'))
|
||||||
|
st3.move_to(torch.device('cuda', 0))
|
||||||
|
|
||||||
|
assert manager.total_number == 4
|
||||||
|
assert manager.total_mem['cpu'] == 80
|
||||||
|
assert manager.total_mem['cuda'] == 12
|
||||||
|
assert manager.state_mem['cpu'][TensorState.HOLD] == 80
|
||||||
|
assert manager.state_mem['cuda'][TensorState.HOLD] == 0
|
||||||
|
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
|
||||||
|
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12
|
||||||
|
|
||||||
|
st1.trans_state(TensorState.COMPUTE)
|
||||||
|
st2.trans_state(TensorState.COMPUTE)
|
||||||
|
st2.trans_state(TensorState.HOLD_AFTER_BWD)
|
||||||
|
|
||||||
|
assert manager.total_number == 4
|
||||||
|
assert manager.total_mem['cpu'] == 80
|
||||||
|
assert manager.total_mem['cuda'] == 12
|
||||||
|
assert manager.state_mem['cpu'][TensorState.HOLD] == 12
|
||||||
|
assert manager.state_mem['cuda'][TensorState.HOLD] == 0
|
||||||
|
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_FWD] == 0
|
||||||
|
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_FWD] == 12
|
||||||
|
assert manager.state_mem['cpu'][TensorState.HOLD_AFTER_BWD] == 60
|
||||||
|
assert manager.state_mem['cuda'][TensorState.HOLD_AFTER_BWD] == 0
|
||||||
|
assert manager.state_mem['cpu'][TensorState.COMPUTE] == 8
|
||||||
|
assert manager.state_mem['cuda'][TensorState.COMPUTE] == 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_gemini_manager()
|
|
@ -6,9 +6,8 @@ from colossalai.utils.cuda import get_current_device
|
||||||
from colossalai.gemini.memory_tracer import MemStatsCollector
|
from colossalai.gemini.memory_tracer import MemStatsCollector
|
||||||
from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
|
from colossalai.gemini.memory_tracer import GLOBAL_MODEL_DATA_TRACER
|
||||||
from colossalai.utils.memory import colo_set_process_memory_fraction
|
from colossalai.utils.memory import colo_set_process_memory_fraction
|
||||||
from colossalai.gemini import StatefulTensorMgr
|
|
||||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||||
from colossalai.zero.sharded_param.tensorful_state import TensorState
|
from colossalai.gemini.stateful_tensor import TensorState
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
from colossalai.testing import rerun_if_address_is_in_use
|
from colossalai.testing import rerun_if_address_is_in_use
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from colossalai.zero.sharded_param.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
|
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
from colossalai.testing import rerun_if_address_is_in_use
|
from colossalai.testing import rerun_if_address_is_in_use
|
||||||
from colossalai.zero.sharded_param import ShardedTensor
|
from colossalai.zero.sharded_param import ShardedTensor
|
||||||
|
from colossalai.gemini.tensor_utils import colo_model_data_tensor_move, colo_model_data_tensor_move_inline
|
||||||
import colossalai
|
import colossalai
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
|
@ -11,7 +11,7 @@ from colossalai.zero.shard_utils import (BucketTensorShardStrategy, TensorShardS
|
||||||
from colossalai.zero.sharded_param import ShardedTensor
|
from colossalai.zero.sharded_param import ShardedTensor
|
||||||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||||
from tests.test_zero.common import CONFIG, allclose
|
from tests.test_zero.common import CONFIG, allclose
|
||||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
|
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||||
|
|
||||||
|
|
||||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||||
|
|
|
@ -2,9 +2,10 @@ import pytest
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.cuda import get_current_device
|
||||||
from colossalai.zero.sharded_param import (StatefulTensor, colo_tensor_mem_usage, colo_model_data_tensor_move,
|
from colossalai.gemini.tensor_utils import (colo_tensor_mem_usage, colo_model_data_tensor_move,
|
||||||
colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
|
colo_model_data_tensor_move_inline, colo_model_data_move_to_cpu,
|
||||||
colo_model_tensor_clone)
|
colo_model_tensor_clone)
|
||||||
|
from colossalai.gemini.stateful_tensor import StatefulTensor
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
from colossalai.testing import rerun_if_address_is_in_use
|
from colossalai.testing import rerun_if_address_is_in_use
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue