[zero] add stateful tensor (#549)

pull/552/head
Jiarui Fang 3 years ago committed by GitHub
parent 107b99ddb1
commit 214da761d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -64,8 +64,8 @@ class ZeroHook(BaseOpHook):
if param.grad is not None: if param.grad is not None:
if param.col_attr.bwd_count == 0: if param.col_attr.bwd_count == 0:
# We haven't stored local accumulated grad yet # We haven't stored local accumulated grad yet
assert param.col_attr.fp32_grad is None assert param.col_attr.fp32_grad.is_null()
param.col_attr.fp32_grad = param.grad.data param.col_attr.fp32_grad.reset_payload(param.grad.data)
param.grad = None param.grad = None
else: else:
# We have stored local accumulated grad # We have stored local accumulated grad

@ -2,6 +2,8 @@ from typing import Any, Callable, List, Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from typing import Union
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
def get_gradient_predivide_factor(world_size: int) -> float: def get_gradient_predivide_factor(world_size: int) -> float:
@ -30,12 +32,17 @@ def alloc_storage(data: torch.Tensor, size: torch.Size) -> None:
def cast_tensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor: def cast_tensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor:
if isinstance(tensor, StatefulTensor):
tensor = tensor.payload
if torch.is_floating_point(tensor) and tensor.dtype is torch.float32: if torch.is_floating_point(tensor) and tensor.dtype is torch.float32:
return tensor.half() return tensor.half()
return tensor return tensor
def cast_tensor_to_fp32(tensor: torch.Tensor) -> torch.Tensor: def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Tensor:
if isinstance(tensor, StatefulTensor):
tensor = tensor.payload
if torch.is_floating_point(tensor) and tensor.dtype is torch.float16: if torch.is_floating_point(tensor) and tensor.dtype is torch.float16:
return tensor.float() return tensor.float()
return tensor return tensor

@ -25,6 +25,7 @@ from torch.nn.parameter import Parameter
from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage, from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
get_gradient_predivide_factor) get_gradient_predivide_factor)
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
class ShardedModelV2(nn.Module): class ShardedModelV2(nn.Module):
@ -233,16 +234,17 @@ class ShardedModelV2(nn.Module):
if self.reuse_fp16_shard: if self.reuse_fp16_shard:
grad_payload = p.col_attr.sharded_data_tensor.payload grad_payload = p.col_attr.sharded_data_tensor.payload
else: else:
grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad) grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
assert isinstance(grad_payload, torch.Tensor)
if p.col_attr.offload_grad: if p.col_attr.offload_grad:
colo_model_data_move_to_cpu(grad_payload) colo_model_data_move_to_cpu(grad_payload)
if p.col_attr.fp32_grad is not None: if not p.col_attr.fp32_grad.is_null():
assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True' assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
p.col_attr.fp32_grad.add_(grad_payload.view_as(p.col_attr.fp32_grad)) p.col_attr.fp32_grad.payload.add_(grad_payload.view_as(p.col_attr.fp32_grad.payload))
grad_payload = p.col_attr.fp32_grad grad_payload = p.col_attr.fp32_grad.payload
p.grad.data = grad_payload p.grad.data = grad_payload
p.col_attr.fp16_grad = None p.col_attr.fp16_grad.set_null()
p.col_attr.fp32_grad = None p.col_attr.fp32_grad.set_null()
@torch.no_grad() @torch.no_grad()
def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]: def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
@ -293,6 +295,8 @@ class ShardedModelV2(nn.Module):
return empty_grad return empty_grad
def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None: def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
assert isinstance(reduced_grad,
torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}"
reduced_grad = reduced_grad.view(-1) reduced_grad = reduced_grad.view(-1)
if self.gradient_postdivide_factor > 1: if self.gradient_postdivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP. # Average grad by world_size for consistency with PyTorch DDP.
@ -301,7 +305,7 @@ class ShardedModelV2(nn.Module):
param.col_attr.sharded_data_tensor.reset_payload(reduced_grad.data) param.col_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
param.col_attr.sharded_data_tensor.is_sharded = True param.col_attr.sharded_data_tensor.is_sharded = True
else: else:
param.col_attr.fp16_grad = reduced_grad.data param.col_attr.fp16_grad = StatefulTensor(reduced_grad.data)
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]': def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()], self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],

@ -3,6 +3,7 @@ import torch.distributed as dist
from colossalai.zero.sharded_param import ShardedTensor from colossalai.zero.sharded_param import ShardedTensor
from typing import Optional, Tuple from typing import Optional, Tuple
from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage
from .tensorful_state import StatefulTensor, TensorState
class ShardedParamV2(object): class ShardedParamV2(object):
@ -12,8 +13,8 @@ class ShardedParamV2(object):
process_group: Optional[dist.ProcessGroup] = None, process_group: Optional[dist.ProcessGroup] = None,
rm_torch_payload=False) -> None: rm_torch_payload=False) -> None:
self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group) self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
self.fp16_grad: Optional[torch.Tensor] = None self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
self.fp32_grad: Optional[torch.Tensor] = None self.fp32_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
# This attribute must be initialized in ShardedModel # This attribute must be initialized in ShardedModel
self.offload_grad: bool = False self.offload_grad: bool = False
@ -64,12 +65,12 @@ class ShardedParamV2(object):
_update_mem_use(self.sharded_data_tensor.payload) _update_mem_use(self.sharded_data_tensor.payload)
address_set.add(self.sharded_data_tensor.payload.data_ptr()) address_set.add(self.sharded_data_tensor.payload.data_ptr())
if self.fp16_grad is not None and self.fp16_grad.data_ptr() not in address_set: if not self.fp16_grad.is_null() and self.fp16_grad.data_ptr() not in address_set:
_update_mem_use(self.fp16_grad) _update_mem_use(self.fp16_grad.payload)
address_set.add(self.fp16_grad.data_ptr()) address_set.add(self.fp16_grad.data_ptr())
if self.fp32_grad is not None and self.fp32_grad.data_ptr() not in address_set: if not self.fp32_grad.is_null() and self.fp32_grad.data_ptr() not in address_set:
_update_mem_use(self.fp32_grad) _update_mem_use(self.fp32_grad.payload)
address_set.add(self.fp32_grad.data_ptr()) address_set.add(self.fp32_grad.data_ptr())
if self.param.data is not None and self.param.data.data_ptr() not in address_set: if self.param.data is not None and self.param.data.data_ptr() not in address_set:

@ -1,30 +1,30 @@
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from typing import Optional from typing import Optional
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
class ShardedTensor(object): class ShardedTensor(StatefulTensor):
def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None: def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
r""" r"""
A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance. A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
""" """
self._payload = tensor super().__init__(tensor)
self.process_group = process_group self.trans_state(TensorState.HOLD)
self.world_size = dist.get_world_size(self.process_group)
self.local_rank = dist.get_rank(self.process_group)
self._is_sharded = False
self._origin_shape = tensor.shape self._origin_shape = tensor.shape
self._origin_numel = tensor.numel() self._origin_numel = tensor.numel()
self._origin_dtype = tensor.dtype self._origin_dtype = tensor.dtype
self._is_sharded = False
@property @property
def origin_numel(self): def origin_numel(self) -> int:
return self._origin_numel return self._origin_numel
@property @property
def origin_shape(self): def origin_shape(self) -> int:
return self._origin_shape return self._origin_shape
@property @property
@ -34,33 +34,3 @@ class ShardedTensor(object):
@is_sharded.setter @is_sharded.setter
def is_sharded(self, flag: bool): def is_sharded(self, flag: bool):
self._is_sharded = flag self._is_sharded = flag
@property
def payload(self):
return self._payload
def copy_payload(self, tensor):
self._payload.view(-1).copy_(tensor.view(-1))
def reset_payload(self, tensor):
del self._payload
self._payload = tensor
@property
def device(self):
return self._payload.device
@property
def dtype(self):
assert self._payload.dtype == self._origin_dtype
return self._origin_dtype
def to(self, device: torch.device):
raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
def to_(self, device: torch.device):
raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")
@property
def shape(self):
return self._payload.shape

@ -0,0 +1,81 @@
from enum import Enum
from logging import NullHandler
import torch
class TensorState(Enum):
FREE = 0
HOLD = 1
HOLD_AFTER_FWD = 2
HOLD_AFTER_BWD = 3
class StatefulTensor(object):
"""A Structure stores a Torch Tensor and labeled states.
PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
https://arxiv.org/abs/2108.05818
"""
def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
self._state = state
if state is not TensorState.FREE:
self._payload = tensor
else:
self._payload = None
def data_ptr(self):
if self._payload is None:
return None
return self._payload.data_ptr()
@property
def state(self) -> TensorState:
return self._state
def set_null(self) -> None:
self._state = TensorState.FREE
self._payload = None
def is_null(self) -> bool:
if self._state == TensorState.FREE:
assert self._payload is None
return True
return False
def trans_state(self, state: TensorState) -> None:
self._state = state
if state == TensorState.FREE:
self._payload = None
@property
def payload(self) -> int:
return self._payload
def copy_payload(self, tensor) -> int:
self._payload.view(-1).copy_(tensor.view(-1))
def reset_payload(self, tensor) -> int:
del self._payload
self._payload = tensor
self.trans_state(TensorState.HOLD)
@property
def device(self) -> torch.device:
return self._payload.device
@property
def dtype(self) -> torch.dtype:
assert self._payload.dtype == self._origin_dtype
return self._origin_dtype
def to(self, device: torch.device):
raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
def to_(self, device: torch.device):
raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")
@property
def shape(self):
return self._payload.shape

@ -12,6 +12,7 @@ from colossalai.zero.sharded_param import ShardedTensor
from colossalai.zero.sharded_param.sharded_param import ShardedParamV2 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
from colossalai.testing import rerun_on_exception from colossalai.testing import rerun_on_exception
from tests.test_zero_data_parallel.common import CONFIG, allclose from tests.test_zero_data_parallel.common import CONFIG, allclose
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
@parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy]) @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
@ -52,7 +53,7 @@ def _run_shard_param_v2(rank, world_size, port):
allclose(sparam.sharded_data_tensor.payload, param_ref.data) allclose(sparam.sharded_data_tensor.payload, param_ref.data)
# Test get memory usage # Test get memory usage
sparam.fp32_grad = torch.randn(2, 3) sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}" assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
@ -62,13 +63,13 @@ def _run_shard_param_v2(rank, world_size, port):
# 4 is size of dummy tensor of param.data # 4 is size of dummy tensor of param.data
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4 assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
sparam.fp16_grad = torch.randn(2, 3).cuda().half() sparam.fp16_grad = StatefulTensor(torch.randn(2, 3).cuda().half())
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4 assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
assert cuda_mem_use == 2 * 3 * 2 assert cuda_mem_use == 2 * 3 * 2
sparam.fp16_grad = None sparam.fp16_grad = StatefulTensor(None)
sparam.fp32_grad = torch.randn(2, 3) sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
sparam.remove_torch_payload() sparam.remove_torch_payload()
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4 assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
@ -82,7 +83,7 @@ def _run_shard_param_v2(rank, world_size, port):
assert cuda_mem_use == 0 assert cuda_mem_use == 0
# reuse torch grad for sparam # reuse torch grad for sparam
sparam.fp32_grad = param.grad sparam.fp32_grad = StatefulTensor(param.grad)
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage() cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2 assert cpu_mem_use == 2 * 3 * 4 * 2
assert cuda_mem_use == 0 assert cuda_mem_use == 0

Loading…
Cancel
Save