[utils] update colo tensor moving APIs (#553)

Jiarui Fang 3 years ago committed by GitHub
parent c44d797072
commit d1211148a7

@@ -1,14 +1,14 @@
 import torch
 from colossalai.utils import get_current_device
-from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
 from typing import Tuple, Union

 _GLOBAL_CUDA_MEM_FRACTION = 1.0


-def colo_tensor_mem_usage(tensor: Union[torch.Tensor, ShardedTensor]) -> Tuple[int, int]:
-    if isinstance(tensor, ShardedTensor):
+def colo_tensor_mem_usage(tensor: Union[torch.Tensor, StatefulTensor]) -> Tuple[int, int]:
+    if issubclass(type(tensor), StatefulTensor):
         t = tensor.payload
     elif isinstance(tensor, torch.Tensor):
         t = tensor
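
Note on this hunk: the refactor swaps the old isinstance(tensor, ShardedTensor) check for issubclass(type(tensor), StatefulTensor), so any subclass of the new StatefulTensor wrapper is unwrapped to its .payload before its memory is measured. A minimal self-contained sketch of that dispatch, using hypothetical stand-in classes (not the real colossalai ones) and a simplified single-int return:

import torch

# Hypothetical stand-ins mirroring the .payload interface in the diff;
# not the real colossalai.zero classes.
class StatefulTensor:
    def __init__(self, tensor: torch.Tensor):
        self._payload = tensor

    @property
    def payload(self) -> torch.Tensor:
        return self._payload

class MyShardedTensor(StatefulTensor):  # any subclass passes the new check
    pass

def tensor_bytes(tensor) -> int:
    # issubclass(type(t), StatefulTensor) accepts StatefulTensor and all of
    # its subclasses, where the old check only accepted ShardedTensor.
    t = tensor.payload if issubclass(type(tensor), StatefulTensor) else tensor
    return t.numel() * t.element_size()

print(tensor_bytes(MyShardedTensor(torch.zeros(4, 4))))  # 64 (fp32 bytes)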
@@ -46,8 +46,8 @@ def colo_cuda_memory_capacity() -> float:
     return torch.cuda.get_device_properties(get_current_device()).total_memory * _GLOBAL_CUDA_MEM_FRACTION


-def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t: Union[ShardedTensor,
+def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_t: Union[StatefulTensor,
                                                                                          torch.Tensor]) -> None:
     """
     A colossal API for model data tensor move.
     The src and target tensors could be resident on both CPU and GPU.
@@ -56,46 +56,44 @@ def colo_model_data_tensor_move(src_t: Union[ShardedTensor, torch.Tensor], tgt_t
     The function will record the communication volume between CPU and GPU.

     Args:
-        t_src (Union[ShardedTensor, torch.Tensor]): source tensor
-        tgt_t (Union[ShardedTensor, torch.Tensor]): target tensor
+        src_t (Union[StatefulTensor, torch.Tensor]): source tensor
+        tgt_t (Union[StatefulTensor, torch.Tensor]): target tensor
     """
-    if isinstance(src_t, ShardedTensor):
+    if issubclass(type(src_t), StatefulTensor):
         src_t_payload = src_t.payload
     else:
         src_t_payload = src_t.data
     src_dev = src_t_payload.device
-    if isinstance(tgt_t, ShardedTensor):
+    if issubclass(type(tgt_t), StatefulTensor):
         tgt_t_payload = tgt_t.payload
     else:
         tgt_t_payload = tgt_t.data
-    tgt_dev = tgt_t_payload.device

     tgt_t_payload.copy_(src_t_payload)

     # remove payload of src_t
-    if isinstance(src_t, ShardedTensor):
+    if issubclass(type(src_t), StatefulTensor):
         src_t.reset_payload(torch.tensor([], device=src_dev, dtype=src_t_payload.dtype))
     else:
         src_t.data = torch.tensor([], device=src_dev, dtype=src_t_payload.dtype)


-def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor],
-                                       target_device: torch.device,
-                                       use_tracer: bool = True) -> None:
+def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor],
+                                       target_device: Union[torch.device, int]) -> None:
     """
     move a tensor to the target_device

     Args:
-        t (Union[ShardedTensor, torch.Tensor]): the tensor be moved
+        t (Union[StatefulTensor, torch.Tensor]): the tensor to be moved
     """
-    if isinstance(t, ShardedTensor):
-        t_payload = t.payload
-    elif isinstance(t, torch.Tensor):
+    if isinstance(t, torch.Tensor):
         t_payload = t
+    elif issubclass(type(t), StatefulTensor):
+        t_payload = t.payload
     else:
         raise TypeError(f'colo_model_data_tensor_move_inline does not accept type {type(t)}')
-    assert isinstance(target_device, torch.device)
+    if isinstance(target_device, int):
+        target_device = torch.device(f'cuda:{target_device}')

     # deal with torch.device('cpu') and torch.device('cpu:0')
     if t_payload.device.type == target_device.type:
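
The other behavioral change in this hunk: target_device may now be a bare int, which the new code normalizes to a CUDA device, where the old code asserted it was already a torch.device. A small sketch of that normalization, with _normalize_device as a hypothetical helper name:

import torch

def _normalize_device(target_device):
    # An int is treated as a CUDA device index, as in the new code path;
    # a torch.device passes through unchanged.
    if isinstance(target_device, int):
        target_device = torch.device(f'cuda:{target_device}')
    return target_device

print(_normalize_device(1))                    # cuda:1
print(_normalize_device(torch.device('cpu')))  # cpu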
@@ -103,16 +101,16 @@ def colo_model_data_tensor_move_inline(t: Union[ShardedTensor, torch.Tensor],
         t_payload.data = t_payload.data.to(target_device)


-def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
+def colo_model_data_move_to_cpu(t: Union[StatefulTensor, torch.Tensor]) -> None:
     """colo_model_data_move_to_cpu
     move a model data tensor from gpu to cpu

     Args:
-        t (Union[ShardedTensor, torch.Tensor]): _description_
+        t (Union[StatefulTensor, torch.Tensor]): the model data tensor to be moved
     """
-    if isinstance(t, ShardedTensor):
+    if issubclass(type(t), StatefulTensor):
         t_payload = t.payload
     elif isinstance(t, torch.Tensor):
         t_payload = t
@@ -126,17 +124,17 @@ def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
         t_payload.data = t_payload.data.cpu()


-def colo_model_tensor_clone(t: Union[ShardedTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
+def colo_model_tensor_clone(t: Union[StatefulTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
     """
     Clone a model data tensor

     Args:
-        t (Union[ShardedTensor, torch.Tensor]): a model data tensor
+        t (Union[StatefulTensor, torch.Tensor]): a model data tensor
         target_device (torch.device): the target device
     Returns:
         torch.Tensor: a cloned torch tensor
     """
-    t_payload = t.payload if isinstance(t, ShardedTensor) else t
+    t_payload = t.payload if issubclass(type(t), StatefulTensor) else t
     ret = t_payload.to(target_device)
     return ret
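
Taken together, colo_model_data_tensor_move copies the source payload into the target and then releases the source by resetting it to an empty tensor on its original device. A plain-torch.Tensor sketch of those semantics (no colossalai imports; the CPU placement is illustrative, either side could live on GPU):

import torch

src = torch.ones(2, 2)          # source payload
tgt = torch.empty(2, 2)         # target payload
src_dev = src.device

tgt.copy_(src)                  # tgt_t_payload.copy_(src_t_payload)
# release the source, as the diff does after the copy
src.data = torch.tensor([], device=src_dev, dtype=src.dtype)

print(tgt.sum().item())  # 4.0 -- data arrived at the target
print(src.numel())       # 0   -- source payload released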
