mirror of https://github.com/hpcaitech/ColossalAI
[zero] add stateful tensor (#549)
parent
107b99ddb1
commit
214da761d4
|
@ -64,8 +64,8 @@ class ZeroHook(BaseOpHook):
|
|||
if param.grad is not None:
|
||||
if param.col_attr.bwd_count == 0:
|
||||
# We haven't stored local accumulated grad yet
|
||||
assert param.col_attr.fp32_grad is None
|
||||
param.col_attr.fp32_grad = param.grad.data
|
||||
assert param.col_attr.fp32_grad.is_null()
|
||||
param.col_attr.fp32_grad.reset_payload(param.grad.data)
|
||||
param.grad = None
|
||||
else:
|
||||
# We have stored local accumulated grad
|
||||
|
|
|
@ -2,6 +2,8 @@ from typing import Any, Callable, List, Tuple
|
|||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Union
|
||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
|
||||
|
||||
|
||||
def get_gradient_predivide_factor(world_size: int) -> float:
|
||||
|
@ -30,12 +32,17 @@ def alloc_storage(data: torch.Tensor, size: torch.Size) -> None:
|
|||
|
||||
|
||||
def cast_tensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor:
|
||||
if isinstance(tensor, StatefulTensor):
|
||||
tensor = tensor.payload
|
||||
if torch.is_floating_point(tensor) and tensor.dtype is torch.float32:
|
||||
return tensor.half()
|
||||
return tensor
|
||||
|
||||
|
||||
def cast_tensor_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
|
||||
def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Tensor:
|
||||
if isinstance(tensor, StatefulTensor):
|
||||
tensor = tensor.payload
|
||||
|
||||
if torch.is_floating_point(tensor) and tensor.dtype is torch.float16:
|
||||
return tensor.float()
|
||||
return tensor
|
||||
|
|
|
@ -25,6 +25,7 @@ from torch.nn.parameter import Parameter
|
|||
|
||||
from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
|
||||
get_gradient_predivide_factor)
|
||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
|
||||
|
||||
|
||||
class ShardedModelV2(nn.Module):
|
||||
|
@ -233,16 +234,17 @@ class ShardedModelV2(nn.Module):
|
|||
if self.reuse_fp16_shard:
|
||||
grad_payload = p.col_attr.sharded_data_tensor.payload
|
||||
else:
|
||||
grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad)
|
||||
grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
|
||||
assert isinstance(grad_payload, torch.Tensor)
|
||||
if p.col_attr.offload_grad:
|
||||
colo_model_data_move_to_cpu(grad_payload)
|
||||
if p.col_attr.fp32_grad is not None:
|
||||
if not p.col_attr.fp32_grad.is_null():
|
||||
assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
|
||||
p.col_attr.fp32_grad.add_(grad_payload.view_as(p.col_attr.fp32_grad))
|
||||
grad_payload = p.col_attr.fp32_grad
|
||||
p.col_attr.fp32_grad.payload.add_(grad_payload.view_as(p.col_attr.fp32_grad.payload))
|
||||
grad_payload = p.col_attr.fp32_grad.payload
|
||||
p.grad.data = grad_payload
|
||||
p.col_attr.fp16_grad = None
|
||||
p.col_attr.fp32_grad = None
|
||||
p.col_attr.fp16_grad.set_null()
|
||||
p.col_attr.fp32_grad.set_null()
|
||||
|
||||
@torch.no_grad()
|
||||
def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
|
@ -293,6 +295,8 @@ class ShardedModelV2(nn.Module):
|
|||
return empty_grad
|
||||
|
||||
def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
|
||||
assert isinstance(reduced_grad,
|
||||
torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}"
|
||||
reduced_grad = reduced_grad.view(-1)
|
||||
if self.gradient_postdivide_factor > 1:
|
||||
# Average grad by world_size for consistency with PyTorch DDP.
|
||||
|
@ -301,7 +305,7 @@ class ShardedModelV2(nn.Module):
|
|||
param.col_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
|
||||
param.col_attr.sharded_data_tensor.is_sharded = True
|
||||
else:
|
||||
param.col_attr.fp16_grad = reduced_grad.data
|
||||
param.col_attr.fp16_grad = StatefulTensor(reduced_grad.data)
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
|
||||
self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
|
||||
|
|
|
@ -3,6 +3,7 @@ import torch.distributed as dist
|
|||
from colossalai.zero.sharded_param import ShardedTensor
|
||||
from typing import Optional, Tuple
|
||||
from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage
|
||||
from .tensorful_state import StatefulTensor, TensorState
|
||||
|
||||
|
||||
class ShardedParamV2(object):
|
||||
|
@ -12,8 +13,8 @@ class ShardedParamV2(object):
|
|||
process_group: Optional[dist.ProcessGroup] = None,
|
||||
rm_torch_payload=False) -> None:
|
||||
self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
|
||||
self.fp16_grad: Optional[torch.Tensor] = None
|
||||
self.fp32_grad: Optional[torch.Tensor] = None
|
||||
self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
|
||||
self.fp32_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
|
||||
# This attribute must be initialized in ShardedModel
|
||||
self.offload_grad: bool = False
|
||||
|
||||
|
@ -64,12 +65,12 @@ class ShardedParamV2(object):
|
|||
_update_mem_use(self.sharded_data_tensor.payload)
|
||||
address_set.add(self.sharded_data_tensor.payload.data_ptr())
|
||||
|
||||
if self.fp16_grad is not None and self.fp16_grad.data_ptr() not in address_set:
|
||||
_update_mem_use(self.fp16_grad)
|
||||
if not self.fp16_grad.is_null() and self.fp16_grad.data_ptr() not in address_set:
|
||||
_update_mem_use(self.fp16_grad.payload)
|
||||
address_set.add(self.fp16_grad.data_ptr())
|
||||
|
||||
if self.fp32_grad is not None and self.fp32_grad.data_ptr() not in address_set:
|
||||
_update_mem_use(self.fp32_grad)
|
||||
if not self.fp32_grad.is_null() and self.fp32_grad.data_ptr() not in address_set:
|
||||
_update_mem_use(self.fp32_grad.payload)
|
||||
address_set.add(self.fp32_grad.data_ptr())
|
||||
|
||||
if self.param.data is not None and self.param.data.data_ptr() not in address_set:
|
||||
|
|
|
@ -1,30 +1,30 @@
|
|||
import torch
|
||||
import torch.distributed as dist
|
||||
from typing import Optional
|
||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
|
||||
|
||||
|
||||
class ShardedTensor(object):
|
||||
class ShardedTensor(StatefulTensor):
|
||||
|
||||
def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
|
||||
r"""
|
||||
A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
|
||||
"""
|
||||
self._payload = tensor
|
||||
self.process_group = process_group
|
||||
self.world_size = dist.get_world_size(self.process_group)
|
||||
self.local_rank = dist.get_rank(self.process_group)
|
||||
self._is_sharded = False
|
||||
super().__init__(tensor)
|
||||
self.trans_state(TensorState.HOLD)
|
||||
|
||||
self._origin_shape = tensor.shape
|
||||
self._origin_numel = tensor.numel()
|
||||
self._origin_dtype = tensor.dtype
|
||||
|
||||
self._is_sharded = False
|
||||
|
||||
@property
|
||||
def origin_numel(self):
|
||||
def origin_numel(self) -> int:
|
||||
return self._origin_numel
|
||||
|
||||
@property
|
||||
def origin_shape(self):
|
||||
def origin_shape(self) -> int:
|
||||
return self._origin_shape
|
||||
|
||||
@property
|
||||
|
@ -34,33 +34,3 @@ class ShardedTensor(object):
|
|||
@is_sharded.setter
|
||||
def is_sharded(self, flag: bool):
|
||||
self._is_sharded = flag
|
||||
|
||||
@property
|
||||
def payload(self):
|
||||
return self._payload
|
||||
|
||||
def copy_payload(self, tensor):
|
||||
self._payload.view(-1).copy_(tensor.view(-1))
|
||||
|
||||
def reset_payload(self, tensor):
|
||||
del self._payload
|
||||
self._payload = tensor
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
return self._payload.device
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
assert self._payload.dtype == self._origin_dtype
|
||||
return self._origin_dtype
|
||||
|
||||
def to(self, device: torch.device):
|
||||
raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
|
||||
|
||||
def to_(self, device: torch.device):
|
||||
raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._payload.shape
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
from enum import Enum
|
||||
from logging import NullHandler
|
||||
import torch
|
||||
|
||||
|
||||
class TensorState(Enum):
|
||||
FREE = 0
|
||||
HOLD = 1
|
||||
HOLD_AFTER_FWD = 2
|
||||
HOLD_AFTER_BWD = 3
|
||||
|
||||
|
||||
class StatefulTensor(object):
|
||||
"""A Structure stores a Torch Tensor and labeled states.
|
||||
|
||||
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
|
||||
|
||||
https://arxiv.org/abs/2108.05818
|
||||
"""
|
||||
|
||||
def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
|
||||
self._state = state
|
||||
if state is not TensorState.FREE:
|
||||
self._payload = tensor
|
||||
else:
|
||||
self._payload = None
|
||||
|
||||
def data_ptr(self):
|
||||
if self._payload is None:
|
||||
return None
|
||||
return self._payload.data_ptr()
|
||||
|
||||
@property
|
||||
def state(self) -> TensorState:
|
||||
return self._state
|
||||
|
||||
def set_null(self) -> None:
|
||||
self._state = TensorState.FREE
|
||||
self._payload = None
|
||||
|
||||
def is_null(self) -> bool:
|
||||
if self._state == TensorState.FREE:
|
||||
assert self._payload is None
|
||||
return True
|
||||
return False
|
||||
|
||||
def trans_state(self, state: TensorState) -> None:
|
||||
self._state = state
|
||||
if state == TensorState.FREE:
|
||||
self._payload = None
|
||||
|
||||
@property
|
||||
def payload(self) -> int:
|
||||
return self._payload
|
||||
|
||||
def copy_payload(self, tensor) -> int:
|
||||
self._payload.view(-1).copy_(tensor.view(-1))
|
||||
|
||||
def reset_payload(self, tensor) -> int:
|
||||
del self._payload
|
||||
self._payload = tensor
|
||||
self.trans_state(TensorState.HOLD)
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
return self._payload.device
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
assert self._payload.dtype == self._origin_dtype
|
||||
return self._origin_dtype
|
||||
|
||||
def to(self, device: torch.device):
|
||||
raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
|
||||
|
||||
def to_(self, device: torch.device):
|
||||
raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")
|
||||
|
||||
@property
|
||||
def shape(self):
|
||||
return self._payload.shape
|
|
@ -12,6 +12,7 @@ from colossalai.zero.sharded_param import ShardedTensor
|
|||
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
|
||||
from colossalai.testing import rerun_on_exception
|
||||
from tests.test_zero_data_parallel.common import CONFIG, allclose
|
||||
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
|
||||
|
||||
|
||||
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
|
||||
|
@ -52,7 +53,7 @@ def _run_shard_param_v2(rank, world_size, port):
|
|||
allclose(sparam.sharded_data_tensor.payload, param_ref.data)
|
||||
|
||||
# Test get memory usage
|
||||
sparam.fp32_grad = torch.randn(2, 3)
|
||||
sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
|
||||
|
||||
|
@ -62,13 +63,13 @@ def _run_shard_param_v2(rank, world_size, port):
|
|||
# 4 is size of dummy tensor of param.data
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
|
||||
|
||||
sparam.fp16_grad = torch.randn(2, 3).cuda().half()
|
||||
sparam.fp16_grad = StatefulTensor(torch.randn(2, 3).cuda().half())
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
|
||||
assert cuda_mem_use == 2 * 3 * 2
|
||||
|
||||
sparam.fp16_grad = None
|
||||
sparam.fp32_grad = torch.randn(2, 3)
|
||||
sparam.fp16_grad = StatefulTensor(None)
|
||||
sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
|
||||
sparam.remove_torch_payload()
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
|
||||
|
@ -82,7 +83,7 @@ def _run_shard_param_v2(rank, world_size, port):
|
|||
assert cuda_mem_use == 0
|
||||
|
||||
# reuse torch grad for sparam
|
||||
sparam.fp32_grad = param.grad
|
||||
sparam.fp32_grad = StatefulTensor(param.grad)
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2
|
||||
assert cuda_mem_use == 0
|
||||
|
|
Loading…
Reference in New Issue