From 214da761d4df0461fa49bd23c501d661bbaa8436 Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Wed, 30 Mar 2022 13:51:37 +0800
Subject: [PATCH] [zero] add stateful tensor (#549)

---
 colossalai/engine/ophooks/zero_hook.py        |  4 +-
 colossalai/zero/sharded_model/_utils.py       |  9 ++-
 .../zero/sharded_model/sharded_model_v2.py    | 18 +++--
 .../zero/sharded_param/sharded_param.py       | 13 +--
 .../zero/sharded_param/sharded_tensor.py      | 46 ++---------
 .../zero/sharded_param/tensorful_state.py     | 81 +++++++++++++++++++
 .../test_shard_param.py                       | 11 +--
 7 files changed, 123 insertions(+), 59 deletions(-)
 create mode 100644 colossalai/zero/sharded_param/tensorful_state.py

diff --git a/colossalai/engine/ophooks/zero_hook.py b/colossalai/engine/ophooks/zero_hook.py
index 01d3a08bb..755f08e9c 100644
--- a/colossalai/engine/ophooks/zero_hook.py
+++ b/colossalai/engine/ophooks/zero_hook.py
@@ -64,8 +64,8 @@ class ZeroHook(BaseOpHook):
             if param.grad is not None:
                 if param.col_attr.bwd_count == 0:
                     # We haven't stored local accumulated grad yet
-                    assert param.col_attr.fp32_grad is None
-                    param.col_attr.fp32_grad = param.grad.data
+                    assert param.col_attr.fp32_grad.is_null()
+                    param.col_attr.fp32_grad.reset_payload(param.grad.data)
                     param.grad = None
                 else:
                     # We have stored local accumulated grad
diff --git a/colossalai/zero/sharded_model/_utils.py b/colossalai/zero/sharded_model/_utils.py
index 682a4ff1e..eed0ff964 100644
--- a/colossalai/zero/sharded_model/_utils.py
+++ b/colossalai/zero/sharded_model/_utils.py
@@ -2,6 +2,8 @@ from typing import Any, Callable, List, Tuple
 
 import torch
 import torch.nn.functional as F
+from typing import Union
+from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
 
 
 def get_gradient_predivide_factor(world_size: int) -> float:
@@ -30,12 +32,17 @@ def alloc_storage(data: torch.Tensor, size: torch.Size) -> None:
 
 
 def cast_tensor_to_fp16(tensor: torch.Tensor) -> torch.Tensor:
+    if isinstance(tensor, StatefulTensor):
+        tensor = tensor.payload
     if torch.is_floating_point(tensor) and tensor.dtype is torch.float32:
         return tensor.half()
     return tensor
 
 
-def cast_tensor_to_fp32(tensor: torch.Tensor) -> torch.Tensor:
+def cast_tensor_to_fp32(tensor: Union[torch.Tensor, StatefulTensor]) -> torch.Tensor:
+    if isinstance(tensor, StatefulTensor):
+        tensor = tensor.payload
+
     if torch.is_floating_point(tensor) and tensor.dtype is torch.float16:
         return tensor.float()
     return tensor
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 68e69f301..03f45ce11 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -25,6 +25,7 @@ from torch.nn.parameter import Parameter
 
 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
+from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
 
 
 class ShardedModelV2(nn.Module):
@@ -233,16 +234,17 @@ class ShardedModelV2(nn.Module):
             if self.reuse_fp16_shard:
                 grad_payload = p.col_attr.sharded_data_tensor.payload
             else:
-                grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad)
+                grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
+            assert isinstance(grad_payload, torch.Tensor)
             if p.col_attr.offload_grad:
                 colo_model_data_move_to_cpu(grad_payload)
-            if p.col_attr.fp32_grad is not None:
+            if not p.col_attr.fp32_grad.is_null():
                 assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
-                p.col_attr.fp32_grad.add_(grad_payload.view_as(p.col_attr.fp32_grad))
-                grad_payload = p.col_attr.fp32_grad
+                p.col_attr.fp32_grad.payload.add_(grad_payload.view_as(p.col_attr.fp32_grad.payload))
+                grad_payload = p.col_attr.fp32_grad.payload
             p.grad.data = grad_payload
-            p.col_attr.fp16_grad = None
-            p.col_attr.fp32_grad = None
+            p.col_attr.fp16_grad.set_null()
+            p.col_attr.fp32_grad.set_null()
 
     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
@@ -293,6 +295,8 @@ class ShardedModelV2(nn.Module):
         return empty_grad
 
     def _reduce_scatter_callback(self, param: Parameter, reduced_grad: torch.Tensor) -> None:
+        assert isinstance(reduced_grad,
+                          torch.Tensor), f"_reduce_scatter_callback accept reduced_grad as {type(reduced_grad)}"
         reduced_grad = reduced_grad.view(-1)
         if self.gradient_postdivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
@@ -301,7 +305,7 @@ class ShardedModelV2(nn.Module):
             param.col_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
             param.col_attr.sharded_data_tensor.is_sharded = True
         else:
-            param.col_attr.fp16_grad = reduced_grad.data
+            param.col_attr.fp16_grad = StatefulTensor(reduced_grad.data)
 
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.col_attr.sharded_data_tensor for p in self.module.parameters()],
diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py
index 64ef16555..b25e2e8a0 100644
--- a/colossalai/zero/sharded_param/sharded_param.py
+++ b/colossalai/zero/sharded_param/sharded_param.py
@@ -3,6 +3,7 @@ import torch.distributed as dist
 from colossalai.zero.sharded_param import ShardedTensor
 from typing import Optional, Tuple
 from colossalai.utils.memory_utils.utils import colo_tensor_mem_usage
+from .tensorful_state import StatefulTensor, TensorState
 
 
 class ShardedParamV2(object):
@@ -12,8 +13,8 @@ class ShardedParamV2(object):
                  process_group: Optional[dist.ProcessGroup] = None,
                  rm_torch_payload=False) -> None:
         self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
-        self.fp16_grad: Optional[torch.Tensor] = None
-        self.fp32_grad: Optional[torch.Tensor] = None
+        self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
+        self.fp32_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         # This attribute must be initialized in ShardedModel
         self.offload_grad: bool = False
 
@@ -64,12 +65,12 @@ class ShardedParamV2(object):
             _update_mem_use(self.sharded_data_tensor.payload)
             address_set.add(self.sharded_data_tensor.payload.data_ptr())
 
-        if self.fp16_grad is not None and self.fp16_grad.data_ptr() not in address_set:
-            _update_mem_use(self.fp16_grad)
+        if not self.fp16_grad.is_null() and self.fp16_grad.data_ptr() not in address_set:
+            _update_mem_use(self.fp16_grad.payload)
             address_set.add(self.fp16_grad.data_ptr())
 
-        if self.fp32_grad is not None and self.fp32_grad.data_ptr() not in address_set:
-            _update_mem_use(self.fp32_grad)
+        if not self.fp32_grad.is_null() and self.fp32_grad.data_ptr() not in address_set:
+            _update_mem_use(self.fp32_grad.payload)
             address_set.add(self.fp32_grad.data_ptr())
 
         if self.param.data is not None and self.param.data.data_ptr() not in address_set:
diff --git a/colossalai/zero/sharded_param/sharded_tensor.py b/colossalai/zero/sharded_param/sharded_tensor.py
index c678f22da..59dc899ed 100644
--- a/colossalai/zero/sharded_param/sharded_tensor.py
+++ b/colossalai/zero/sharded_param/sharded_tensor.py
@@ -1,30 +1,30 @@
 import torch
 import torch.distributed as dist
 from typing import Optional
+from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
 
 
-class ShardedTensor(object):
+class ShardedTensor(StatefulTensor):
 
     def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
         r"""
         A tensor sharded in multiple processes. Constructed from an existing torch.Tensor instance.
         """
-        self._payload = tensor
-        self.process_group = process_group
-        self.world_size = dist.get_world_size(self.process_group)
-        self.local_rank = dist.get_rank(self.process_group)
-        self._is_sharded = False
+        super().__init__(tensor)
+        self.trans_state(TensorState.HOLD)
 
         self._origin_shape = tensor.shape
         self._origin_numel = tensor.numel()
         self._origin_dtype = tensor.dtype
+        self._is_sharded = False
+
 
     @property
-    def origin_numel(self):
+    def origin_numel(self) -> int:
         return self._origin_numel
 
     @property
-    def origin_shape(self):
+    def origin_shape(self) -> int:
         return self._origin_shape
 
     @property
@@ -34,33 +34,3 @@
     @is_sharded.setter
     def is_sharded(self, flag: bool):
         self._is_sharded = flag
-
-    @property
-    def payload(self):
-        return self._payload
-
-    def copy_payload(self, tensor):
-        self._payload.view(-1).copy_(tensor.view(-1))
-
-    def reset_payload(self, tensor):
-        del self._payload
-        self._payload = tensor
-
-    @property
-    def device(self):
-        return self._payload.device
-
-    @property
-    def dtype(self):
-        assert self._payload.dtype == self._origin_dtype
-        return self._origin_dtype
-
-    def to(self, device: torch.device):
-        raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
-
-    def to_(self, device: torch.device):
-        raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")
-
-    @property
-    def shape(self):
-        return self._payload.shape
diff --git a/colossalai/zero/sharded_param/tensorful_state.py b/colossalai/zero/sharded_param/tensorful_state.py
new file mode 100644
index 000000000..3be01f1a6
--- /dev/null
+++ b/colossalai/zero/sharded_param/tensorful_state.py
@@ -0,0 +1,81 @@
+from enum import Enum
+from logging import NullHandler
+import torch
+
+
+class TensorState(Enum):
+    FREE = 0
+    HOLD = 1
+    HOLD_AFTER_FWD = 2
+    HOLD_AFTER_BWD = 3
+
+
+class StatefulTensor(object):
+    """A Structure stores a Torch Tensor and labeled states.
+
+    PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
+
+    https://arxiv.org/abs/2108.05818
+    """
+
+    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
+        self._state = state
+        if state is not TensorState.FREE:
+            self._payload = tensor
+        else:
+            self._payload = None
+
+    def data_ptr(self):
+        if self._payload is None:
+            return None
+        return self._payload.data_ptr()
+
+    @property
+    def state(self) -> TensorState:
+        return self._state
+
+    def set_null(self) -> None:
+        self._state = TensorState.FREE
+        self._payload = None
+
+    def is_null(self) -> bool:
+        if self._state == TensorState.FREE:
+            assert self._payload is None
+            return True
+        return False
+
+    def trans_state(self, state: TensorState) -> None:
+        self._state = state
+        if state == TensorState.FREE:
+            self._payload = None
+
+    @property
+    def payload(self) -> int:
+        return self._payload
+
+    def copy_payload(self, tensor) -> int:
+        self._payload.view(-1).copy_(tensor.view(-1))
+
+    def reset_payload(self, tensor) -> int:
+        del self._payload
+        self._payload = tensor
+        self.trans_state(TensorState.HOLD)
+
+    @property
+    def device(self) -> torch.device:
+        return self._payload.device
+
+    @property
+    def dtype(self) -> torch.dtype:
+        assert self._payload.dtype == self._origin_dtype
+        return self._origin_dtype
+
+    def to(self, device: torch.device):
+        raise RuntimeError("Use colo_model_tensor_move install of call .to() on ShardedTensor")
+
+    def to_(self, device: torch.device):
+        raise RuntimeError("Use colo_model_tensor_move install of call .to_() on ShardedTensor")
+
+    @property
+    def shape(self):
+        return self._payload.shape
diff --git a/tests/test_zero_data_parallel/test_shard_param.py b/tests/test_zero_data_parallel/test_shard_param.py
index e2694baf7..5740223d6 100644
--- a/tests/test_zero_data_parallel/test_shard_param.py
+++ b/tests/test_zero_data_parallel/test_shard_param.py
@@ -12,6 +12,7 @@ from colossalai.zero.sharded_param import ShardedTensor
 from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 from colossalai.testing import rerun_on_exception
 from tests.test_zero_data_parallel.common import CONFIG, allclose
+from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
 
 
 @parameterize("shard_strategy_class", [TensorShardStrategy, BucketTensorShardStrategy])
@@ -52,7 +53,7 @@ def _run_shard_param_v2(rank, world_size, port):
     allclose(sparam.sharded_data_tensor.payload, param_ref.data)
 
     # Test get memory usage
-    sparam.fp32_grad = torch.randn(2, 3)
+    sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2, f"cpu_mem_use: {cpu_mem_use}"
 
@@ -62,13 +63,13 @@ def _run_shard_param_v2(rank, world_size, port):
     # 4 is size of dummy tensor of param.data
     assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
 
-    sparam.fp16_grad = torch.randn(2, 3).cuda().half()
+    sparam.fp16_grad = StatefulTensor(torch.randn(2, 3).cuda().half())
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
     assert cuda_mem_use == 2 * 3 * 2
 
-    sparam.fp16_grad = None
-    sparam.fp32_grad = torch.randn(2, 3)
+    sparam.fp16_grad = StatefulTensor(None)
+    sparam.fp32_grad = StatefulTensor(torch.randn(2, 3))
     sparam.remove_torch_payload()
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
@@ -82,7 +83,7 @@ def _run_shard_param_v2(rank, world_size, port):
     assert cuda_mem_use == 0
 
     # reuse torch grad for sparam
-    sparam.fp32_grad = param.grad
+    sparam.fp32_grad = StatefulTensor(param.grad)
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
     assert cpu_mem_use == 2 * 3 * 4 * 2
     assert cuda_mem_use == 0
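
Usage sketch (editor's addition, not part of the diff above): the snippet below walks the lifecycle of the StatefulTensor introduced in tensorful_state.py, in the same order the call sites touched by this patch drive it — an empty FREE buffer like ShardedParamV2.fp32_grad, reset_payload() on the first backward pass (zero_hook.py), in-place accumulation through .payload, and a final set_null() once the gradient has been written back (sharded_model_v2.py). The tensor shape is illustrative only.

import torch

from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState

# A gradient buffer starts out empty, mirroring ShardedParamV2.fp16_grad/fp32_grad.
fp32_grad = StatefulTensor(None, TensorState.FREE)
assert fp32_grad.is_null()

# First backward pass: store the accumulated grad (cf. zero_hook.py).
fp32_grad.reset_payload(torch.randn(2, 3))    # payload set, state -> HOLD
assert not fp32_grad.is_null()

# Later passes accumulate in place through .payload (cf. sharded_model_v2.py).
fp32_grad.payload.add_(torch.randn(2, 3))

# After the grad is handed back to param.grad, the buffer is released.
fp32_grad.set_null()                          # state -> FREE, payload dropped
assert fp32_grad.payload is None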
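
A second sketch (also an editor's addition, assuming only what the diff shows): the cast helpers in colossalai/zero/sharded_model/_utils.py now unwrap a StatefulTensor before casting, so callers may pass either a plain tensor or the stateful wrapper; the return value is always a plain torch.Tensor.

import torch

from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
from colossalai.zero.sharded_param.tensorful_state import StatefulTensor

# cast_tensor_to_fp32 accepts Union[torch.Tensor, StatefulTensor] after this patch.
fp16_grad = StatefulTensor(torch.randn(4).half())
fp32 = cast_tensor_to_fp32(fp16_grad)    # unwraps .payload, then upcasts fp16 -> fp32

assert isinstance(fp32, torch.Tensor)
assert fp32.dtype is torch.float32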