mirror of https://github.com/hpcaitech/ColossalAI
[zero] add stateful tensor (#549)
parent 107b99ddb1
commit 214da761d4
@@ -64,8 +64,8 @@ class ZeroHook(BaseOpHook):
             if param.grad is not None:
                 if param.col_attr.bwd_count == 0:
                     # We haven't stored local accumulated grad yet
-                    assert param.col_attr.fp32_grad is None
-                    param.col_attr.fp32_grad = param.grad.data
+                    assert param.col_attr.fp32_grad.is_null()
+                    param.col_attr.fp32_grad.reset_payload(param.grad.data)
                     param.grad = None
                 else:
                     # We have stored local accumulated grad
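The hook no longer compares fp32_grad against None; with the gradient slot held as a StatefulTensor, emptiness is expressed by the FREE state and a new payload is adopted with reset_payload. A minimal sketch of that bookkeeping, assuming the StatefulTensor API added by the new tensorful_state.py later in this commit (the local names stand in for param.col_attr.fp32_grad and param.grad):

    import torch
    from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState

    # start with an empty slot, as ShardedParamV2.__init__ now does
    fp32_grad = StatefulTensor(None, TensorState.FREE)
    grad = torch.randn(4)

    if fp32_grad.is_null():
        # first backward pass: adopt the freshly computed gradient as the payload
        fp32_grad.reset_payload(grad)      # state becomes HOLD
    else:
        # later passes accumulate into the existing payload instead
        fp32_grad.payload.add_(grad)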
@@ -2,6 +2,8 @@ from typing import Any, Callable, List, Tuple
 
 import torch
 import torch.nn.functional as F
+from typing import Union
+from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
 
 
 def get_gradient_predivide_factor(world_size: int) -> float:
@@ -30,12 +32,17 @@ def alloc_storage(data: torch.Tensor, size: torch.Size) -> None:
 
 
 def cast_tensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor:
+    if isinstance(tensor, StatefulTensor):
+        tensor = tensor.payload
     if torch.is_floating_point(tensor) and tensor.dtype is torch.float32:
         return tensor.half()
     return tensor
 
 
-def cast_tensor_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
+def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Tensor:
+    if isinstance(tensor, StatefulTensor):
+        tensor = tensor.payload
+
    if torch.is_floating_point(tensor) and tensor.dtype is torch.float16:
        return tensor.float()
    return tensor
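After this change the cast helpers accept either a raw tensor or a StatefulTensor and unwrap the payload first. A small usage sketch; the module path colossalai.zero.sharded_model._utils is an assumption inferred from the relative "from ._utils import ..." line in the next hunk, and the tensor values are arbitrary:

    import torch
    from colossalai.zero.sharded_model._utils import cast_tensor_to_fp16, cast_tensor_to_fp32  # assumed path
    from colossalai.zero.sharded_param.tensorful_state import StatefulTensor

    t32 = torch.randn(2, 2)                            # float32
    half = cast_tensor_to_fp16(StatefulTensor(t32))    # payload is unwrapped, then .half()
    back = cast_tensor_to_fp32(StatefulTensor(half))   # payload is unwrapped, then .float()
    assert half.dtype is torch.float16 and back.dtype is torch.float32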
@@ -25,6 +25,7 @@ from torch.nn.parameter import Parameter
 
 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
+from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
 
 
 class ShardedModelV2(nn.Module):
@@ -233,16 +234,17 @@ class ShardedModelV2(nn.Module):
             if self.reuse_fp16_shard:
                 grad_payload = p.col_attr.sharded_data_tensor.payload
             else:
-                grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad)
+                grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
+            assert isinstance(grad_payload, torch.Tensor)
             if p.col_attr.offload_grad:
                 colo_model_data_move_to_cpu(grad_payload)
-            if p.col_attr.fp32_grad is not None:
+            if not p.col_attr.fp32_grad.is_null():
                 assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
-                p.col_attr.fp32_grad.add_(grad_payload.view_as(p.col_attr.fp32_grad))
-                grad_payload = p.col_attr.fp32_grad
+                p.col_attr.fp32_grad.payload.add_(grad_payload.view_as(p.col_attr.fp32_grad.payload))
+                grad_payload = p.col_attr.fp32_grad.payload
             p.grad.data = grad_payload
-            p.col_attr.fp16_grad = None
-            p.col_attr.fp32_grad = None
+            p.col_attr.fp16_grad.set_null()
+            p.col_attr.fp32_grad.set_null()
 
    @torch.no_grad()
    def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
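With both gradient buffers wrapped, the accumulation branch reads and writes through .payload and releases the buffers with set_null() instead of assigning None. A reduced sketch of that flow, assuming the same StatefulTensor API; the local names are stand-ins for p.col_attr.fp16_grad and p.col_attr.fp32_grad, not the real attributes:

    import torch
    from colossalai.zero.sharded_param.tensorful_state import StatefulTensor

    fp16_grad = StatefulTensor(torch.randn(4).half())   # stand-in for p.col_attr.fp16_grad
    fp32_grad = StatefulTensor(torch.zeros(4))          # stand-in for p.col_attr.fp32_grad

    grad_payload = fp16_grad.payload.float()
    if not fp32_grad.is_null():
        # accumulate into the fp32 buffer, then hand its payload to p.grad
        fp32_grad.payload.add_(grad_payload.view_as(fp32_grad.payload))
        grad_payload = fp32_grad.payload

    # after p.grad.data is set, both buffers drop their payloads and return to FREE
    fp16_grad.set_null()
    fp32_grad.set_null()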
@@ -293,6 +295,8 @@ class ShardedModelV2(nn.Module):
         return empty_grad
 
     def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
+        assert isinstance(reduced_grad,
+                          torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}"
         reduced_grad = reduced_grad.view(-1)
         if self.gradient_postdivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
@@ -301,7 +305,7 @@ class ShardedModelV2(nn.Module):
             param.col_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
             param.col_attr.sharded_data_tensor.is_sharded = True
         else:
-            param.col_attr.fp16_grad = reduced_grad.data
+            param.col_attr.fp16_grad = StatefulTensor(reduced_grad.data)
 
    def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
        self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
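When fp16 shards are not reused, the reduced gradient is now stored by wrapping it in a fresh StatefulTensor rather than assigned as a bare tensor, and the assert added above guards that callers still pass a plain torch.Tensor. A short sketch of what the else-branch ends up holding, with a local stand-in for the reduce-scatter output:

    import torch
    from colossalai.zero.sharded_param.tensorful_state import StatefulTensor

    reduced_grad = torch.randn(8)                    # stand-in for the reduce-scatter result
    assert isinstance(reduced_grad, torch.Tensor)
    fp16_grad = StatefulTensor(reduced_grad.data)    # what the else-branch now stores
    assert not fp16_grad.is_null()
    assert fp16_grad.data_ptr() == reduced_grad.data_ptr()   # same underlying storage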
@@ -3,6 +3,7 @@ import torch.distributed as dist
 from colossalai.zero.sharded_param import ShardedTensor
 from typing import Optional, Tuple
 from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage
+from .tensorful_state import StatefulTensor, TensorState
 
 
 class ShardedParamV2(object):
@@ -12,8 +13,8 @@ class ShardedParamV2(object):
                  process_group: Optional[dist.ProcessGroup] = None,
                  rm_torch_payload=False) -> None:
         self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
-        self.fp16_grad: Optional[torch.Tensor] = None
-        self.fp32_grad: Optional[torch.Tensor] = None
+        self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
+        self.fp32_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         # This attribute must be initialized in ShardedModel
         self.offload_grad: bool = False
 
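fp16_grad and fp32_grad now start life as empty StatefulTensors in the FREE state, so downstream code tests emptiness with is_null() instead of comparing against None. A minimal check of that invariant, using the constructor defaults added here:

    from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState

    empty = StatefulTensor(None, TensorState.FREE)
    assert empty.is_null()                               # FREE state means no payload
    assert empty.payload is None and empty.data_ptr() is None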
@@ -64,12 +65,12 @@ class ShardedParamV2(object):
         _update_mem_use(self.sharded_data_tensor.payload)
         address_set.add(self.sharded_data_tensor.payload.data_ptr())
 
-        if self.fp16_grad is not None and self.fp16_grad.data_ptr() not in address_set:
-            _update_mem_use(self.fp16_grad)
+        if not self.fp16_grad.is_null() and self.fp16_grad.data_ptr() not in address_set:
+            _update_mem_use(self.fp16_grad.payload)
             address_set.add(self.fp16_grad.data_ptr())
 
-        if self.fp32_grad is not None and self.fp32_grad.data_ptr() not in address_set:
-            _update_mem_use(self.fp32_grad)
+        if not self.fp32_grad.is_null() and self.fp32_grad.data_ptr() not in address_set:
+            _update_mem_use(self.fp32_grad.payload)
             address_set.add(self.fp32_grad.data_ptr())
 
         if self.param.data is not None and self.param.data.data_ptr() not in address_set:
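get_memory_usage still deduplicates aliased buffers by raw address, but the address now comes from StatefulTensor.data_ptr(), which forwards to the payload and returns None while the slot is FREE. A small sketch of the dedup idea with hypothetical local names:

    import torch
    from colossalai.zero.sharded_param.tensorful_state import StatefulTensor

    param_data = torch.randn(2, 3)
    fp32_grad = StatefulTensor(param_data)     # payload aliases the same storage as param_data

    address_set = {param_data.data_ptr()}
    if not fp32_grad.is_null() and fp32_grad.data_ptr() not in address_set:
        # would be counted here; skipped because the payload aliases param_data
        address_set.add(fp32_grad.data_ptr())
    assert len(address_set) == 1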
@@ -1,30 +1,30 @@
 import torch
 import torch.distributed as dist
 from typing import Optional
+from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
 
 
-class ShardedTensor(object):
+class ShardedTensor(StatefulTensor):
 
     def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
         r"""
         A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
         """
-        self._payload = tensor
-        self.process_group = process_group
-        self.world_size = dist.get_world_size(self.process_group)
-        self.local_rank = dist.get_rank(self.process_group)
-        self._is_sharded = False
+        super().__init__(tensor)
+        self.trans_state(TensorState.HOLD)
 
         self._origin_shape = tensor.shape
         self._origin_numel = tensor.numel()
         self._origin_dtype = tensor.dtype
 
+        self._is_sharded = False
+
     @property
-    def origin_numel(self):
+    def origin_numel(self) -> int:
         return self._origin_numel
 
     @property
-    def origin_shape(self):
+    def origin_shape(self) -> int:
         return self._origin_shape
 
     @property
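ShardedTensor keeps its sharding metadata but delegates payload storage and state tracking to the new base class, which is why the payload, copy_payload, reset_payload, device, dtype and shape members are deleted in the next hunk. A hedged sketch of the inherited behaviour; only the base class is exercised here, since a real ShardedTensor is normally built inside an initialized distributed setup:

    import torch
    from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState

    t = StatefulTensor(torch.ones(2, 2))     # ShardedTensor.__init__ now reaches this via super()
    t.trans_state(TensorState.HOLD)          # and then pins the HOLD state, as shown above
    assert t.shape == (2, 2) and t.state is TensorState.HOLD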
@@ -34,33 +34,3 @@ class ShardedTensor(object):
     @is_sharded.setter
     def is_sharded(self, flag: bool):
         self._is_sharded = flag
-
-    @property
-    def payload(self):
-        return self._payload
-
-    def copy_payload(self, tensor):
-        self._payload.view(-1).copy_(tensor.view(-1))
-
-    def reset_payload(self, tensor):
-        del self._payload
-        self._payload = tensor
-
-    @property
-    def device(self):
-        return self._payload.device
-
-    @property
-    def dtype(self):
-        assert self._payload.dtype == self._origin_dtype
-        return self._origin_dtype
-
-    def to(self, device: torch.device):
-        raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
-
-    def to_(self, device: torch.device):
-        raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")
-
-    @property
-    def shape(self):
-        return self._payload.shape
@@ -0,0 +1,81 @@
+from enum import Enum
+from logging import NullHandler
+import torch
+
+
+class TensorState(Enum):
+    FREE = 0
+    HOLD = 1
+    HOLD_AFTER_FWD = 2
+    HOLD_AFTER_BWD = 3
+
+
+class StatefulTensor(object):
+    """A Structure stores a Torch Tensor and labeled states.
+
+    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
+
+    https://arxiv.org/abs/2108.05818
+    """
+
+    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
+        self._state = state
+        if state is not TensorState.FREE:
+            self._payload = tensor
+        else:
+            self._payload = None
+
+    def data_ptr(self):
+        if self._payload is None:
+            return None
+        return self._payload.data_ptr()
+
+    @property
+    def state(self) -> TensorState:
+        return self._state
+
+    def set_null(self) -> None:
+        self._state = TensorState.FREE
+        self._payload = None
+
+    def is_null(self) -> bool:
+        if self._state == TensorState.FREE:
+            assert self._payload is None
+            return True
+        return False
+
+    def trans_state(self, state: TensorState) -> None:
+        self._state = state
+        if state == TensorState.FREE:
+            self._payload = None
+
+    @property
+    def payload(self) -> int:
+        return self._payload
+
+    def copy_payload(self, tensor) -> int:
+        self._payload.view(-1).copy_(tensor.view(-1))
+
+    def reset_payload(self, tensor) -> int:
+        del self._payload
+        self._payload = tensor
+        self.trans_state(TensorState.HOLD)
+
+    @property
+    def device(self) -> torch.device:
+        return self._payload.device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        assert self._payload.dtype == self._origin_dtype
+        return self._origin_dtype
+
+    def to(self, device: torch.device):
+        raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
+
+    def to_(self, device: torch.device):
+        raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")
+
+    @property
+    def shape(self):
+        return self._payload.shape
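The new file is the heart of the commit: a StatefulTensor couples a payload with a small lifecycle (FREE, HOLD, HOLD_AFTER_FWD, HOLD_AFTER_BWD), following the PatrickStar chunk-based memory-management design cited in its docstring. Note that the dtype property reads self._origin_dtype, which only subclasses such as ShardedTensor define. A short walkthrough of the lifecycle, assuming the module path used by the imports elsewhere in this diff:

    import torch
    from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState

    t = StatefulTensor(None, TensorState.FREE)            # empty slot
    assert t.is_null() and t.data_ptr() is None

    t.reset_payload(torch.zeros(3))                       # adopt a tensor; state -> HOLD
    assert t.state is TensorState.HOLD
    t.copy_payload(torch.arange(3, dtype=torch.float32))  # in-place copy into the payload
    assert torch.equal(t.payload, torch.tensor([0., 1., 2.]))

    t.trans_state(TensorState.HOLD_AFTER_BWD)             # bookkeeping only; payload kept
    t.set_null()                                          # back to FREE; payload dropped
    assert t.is_null()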
@@ -12,6 +12,7 @@ from colossalai.zero.sharded_param import ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.testing import rerun_on_exception
 from tests.test_zero_data_parallel.common import CONFIG, allclose
+from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
 
 
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
@@ -52,7 +53,7 @@ def _run_shard_param_v2(rank, world_size, port):
     allclose(sparam.sharded_data_tensor.payload, param_ref.data)
 
     # Test get memory usage
-    sparam.fp32_grad = torch.randn(2, 3)
+    sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
 
@@ -62,13 +63,13 @@ def _run_shard_param_v2(rank, world_size, port):
     # 4 is size of dummy tensor of param.data
     assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
 
-    sparam.fp16_grad = torch.randn(2, 3).cuda().half()
+    sparam.fp16_grad = StatefulTensor(torch.randn(2, 3).cuda().half())
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
     assert cuda_mem_use == 2 * 3 * 2
 
-    sparam.fp16_grad = None
-    sparam.fp32_grad = torch.randn(2, 3)
+    sparam.fp16_grad = StatefulTensor(None)
+    sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
     sparam.remove_torch_payload()
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
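The constants in these assertions are plain byte counts: each dummy tensor has 2 * 3 elements, fp32 costs 4 bytes per element and fp16 costs 2, and the trailing + 4 is the one-element dummy left for param.data, per the test's own comment. Reading the factor of 2 in the CPU total as the fp32 gradient plus the fp32 sharded payload is an interpretation, not something the diff states; a quick check of the arithmetic only:

    elem = 2 * 3                 # elements in each dummy (2, 3) tensor
    fp32_bytes = elem * 4        # 24 bytes for one fp32 tensor
    fp16_bytes = elem * 2        # 12 bytes for one fp16 tensor

    assert fp32_bytes * 2 == 2 * 3 * 4 * 2          # two fp32 tensors on CPU
    assert fp32_bytes * 2 + 4 == 2 * 3 * 4 * 2 + 4  # plus the 4-byte dummy param.data
    assert fp16_bytes == 2 * 3 * 2                  # the fp16 gradient placed on CUDA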
@@ -82,7 +83,7 @@ def _run_shard_param_v2(rank, world_size, port):
     assert cuda_mem_use == 0
 
     # reuse torch grad for sparam
-    sparam.fp32_grad = param.grad
+    sparam.fp32_grad = StatefulTensor(param.grad)
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2
     assert cuda_mem_use == 0