diff --git a/colossalai/engine/ophooks/zero_hook.py b/colossalai/engine/ophooks/zero_hook.py
index 755f08e9c..d30b786bc 100644
--- a/colossalai/engine/ophooks/zero_hook.py
+++ b/colossalai/engine/ophooks/zero_hook.py
@@ -6,6 +6,7 @@ from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
+from colossalai.zero.sharded_param.tensorful_state import TensorState
 
 from ._base_ophook import BaseOpHook
 from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move_inline
@@ -42,7 +43,13 @@ class ZeroHook(BaseOpHook):
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
 
+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+
     def post_fwd_exec(self, module: torch.nn.Module, *args):
+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_FWD)
+
         tensor_list = []
         for param in module.parameters(recurse=False):
             assert hasattr(param, 'col_attr')
@@ -65,7 +72,10 @@ class ZeroHook(BaseOpHook):
             if param.col_attr.bwd_count == 0:
                 # We haven't stored local accumulated grad yet
                 assert param.col_attr.fp32_grad.is_null()
+
+                # Allocate grad fp32 memory space here
                 param.col_attr.fp32_grad.reset_payload(param.grad.data)
+                # TODO(jiaruifang) we should set grad fp16 state to HOLD here.
                 param.grad = None
             else:
                 # We have stored local accumulated grad
@@ -75,12 +85,19 @@ class ZeroHook(BaseOpHook):
         if self._memstarts_collector:
             self._memstarts_collector.sample_memstats()
 
+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)
+
     def post_bwd_exec(self, module: torch.nn.Module, input):
+        for param in module.parameters(recurse=False):
+            param.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
+
         tensor_list = []
         for param in module.parameters(recurse=False):
             assert hasattr(param, 'col_attr')
             tensor_list.append(param.col_attr.sharded_data_tensor)
         self.shard_strategy.shard(tensor_list, self.process_group)
+
         for param in module.parameters(recurse=False):
             param.col_attr.remove_torch_payload()
 
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index 03f45ce11..846bfc016 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -25,7 +25,7 @@ from torch.nn.parameter import Parameter
 from ._utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad, free_storage,
                      get_gradient_predivide_factor)
 
-from colossalai.zero.sharded_param.tensorful_state import StatefulTensor
+from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
 
 
 class ShardedModelV2(nn.Module):
@@ -158,12 +158,25 @@ class ShardedModelV2(nn.Module):
                 f.write(str(self._memstats_collector.non_model_data_cuda_GB))
                 f.write('\n')
 
-    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+    def _pre_forward_operations(self):
         if self._iter_cnter == 0 and self._memstats_collector:
-            # the opeartion will affect the flag in ZeroHook
+            # the operation will affect the memory tracer behavior in ZeroHook
             self._memstats_collector.start_collection()
+
+        for p in self.module.parameters():
+            if hasattr(p, 'col_attr'):
+                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+
+    def _post_forward_operations(self):
+        for p in self.module.parameters():
+            if hasattr(p, 'col_attr'):
+                p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD)
+
+    def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
+        self._pre_forward_operations()
         args, kwargs = cast_float_arguments(cast_tensor_to_fp16, *args, **kwargs)
         outputs = self.module(*args, **kwargs)
+        self._post_forward_operations()
         return outputs
 
     def backward(self, loss):
@@ -195,9 +208,15 @@ class ShardedModelV2(nn.Module):
     def _post_backward_operations(self) -> None:
         """
         The method includes operations required to be processed after backward
+        1. update the memory tracer.
+        2. flush the gradients in buckets, reducing partial gradients in each process.
+        3. shard tensors not dealt with in the zero hook
+        4. move the sharded param grad payload to param.grad
         """
+        # 1. update the memory tracer.
         self._update_memstats()
+        # 2. flush the gradients in buckets, reducing partial gradients in each process.
         if self._require_backward_grad_sync:
             # Flush any unreduced buckets in the post_backward stream.
             with torch.cuda.stream(self.comm_stream):
@@ -207,44 +226,50 @@ class ShardedModelV2(nn.Module):
                 # Wait for the non-blocking GPU -> CPU grad transfers to finish.
                 torch.cuda.current_stream().synchronize()
             self.reducer.free()
-        # In case some post bwd hook is not fired
+        # 3. shard tensors not dealt with in the zero hook
         if self.shard_param:
             tensor_list = []
             for p in self.module.parameters():
                 if not p.col_attr.param_is_sharded:
                     tensor_list.append(p.col_attr.sharded_data_tensor)
+                    p.col_attr.sharded_data_tensor.trans_state(TensorState.HOLD_AFTER_BWD)
             self.shard_strategy.shard(tensor_list, self.process_group)
+
+        # 4. move the sharded param grad payload to param.grad
         for p in self.module.parameters():
            p.col_attr.bwd_count = 0
            if not p.requires_grad:
                continue
-            # Leave the gradient accumulation state as-is if not synchronizing this pass. This ensures p.grad
-            # remains the unsharded gradient accumulated from prior no-sync passes, and _saved_grad_shard
-            # remains the sharded gradient from the last synchronized pass. This also allows interleaved no-sync and
-            # sync passes, if desired.
+            # Leave the gradient accumulation state (_require_backward_grad_sync) as-is if not synchronizing this pass.
+            # NOTE: a no-sync pass skips the gradient all-reduce over the process group, while a sync pass performs it.
+            # If _require_backward_grad_sync is False,
+            # p.grad remains the accumulated unsharded gradient from prior no-sync passes.
+            # Interleaving no-sync passes with sync passes is also allowed, if desired.
             if not self._require_backward_grad_sync:
                 continue
-            # Write grad back to p.grad and set p.col_attr.grad to None
+            # Write the grad payload kept by the sharded param back to p.grad,
+            # and set p.col_attr.grad to None
             # As sharded optimizer only update a shard of param,
             # no matter whether we shard param in sharded model
             # We have to make sure the grad is a flat tensor shard
-            # If world size == 1 and sharded param,
+            # If world size == 1 and param is sharded,
             # the shape `grad` is the same as unsharded param
             # So we can just use `view(-1)` to ensure grad is a flat tensor shard
             if self.reuse_fp16_shard:
-                grad_payload = p.col_attr.sharded_data_tensor.payload
+                grad_fp16_payload = p.col_attr.sharded_data_tensor.payload
             else:
-                grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
-            assert isinstance(grad_payload, torch.Tensor)
+                grad_fp16_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad.payload)
+            assert isinstance(grad_fp16_payload, torch.Tensor)
             if p.col_attr.offload_grad:
-                colo_model_data_move_to_cpu(grad_payload)
+                colo_model_data_move_to_cpu(grad_fp16_payload)
             if not p.col_attr.fp32_grad.is_null():
                 assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
-                p.col_attr.fp32_grad.payload.add_(grad_payload.view_as(p.col_attr.fp32_grad.payload))
-                grad_payload = p.col_attr.fp32_grad.payload
-            p.grad.data = grad_payload
+                p.col_attr.fp32_grad.payload.add_(grad_fp16_payload.view_as(p.col_attr.fp32_grad.payload))
+                grad_fp16_payload = p.col_attr.fp32_grad.payload
+                p.col_attr.fp32_grad.set_null()
+
+            p.grad.data = grad_fp16_payload
             p.col_attr.fp16_grad.set_null()
-            p.col_attr.fp32_grad.set_null()
 
     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
diff --git a/colossalai/zero/sharded_optim/sharded_optim_v2.py b/colossalai/zero/sharded_optim/sharded_optim_v2.py
index 90d908044..fa3b2daa4 100644
--- a/colossalai/zero/sharded_optim/sharded_optim_v2.py
+++ b/colossalai/zero/sharded_optim/sharded_optim_v2.py
@@ -79,7 +79,6 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  growth_interval: float = 1000,
                  hysteresis: float = 2,
                  max_scale: int = 2**32,
-                 use_memory_tracer=False,
                  dp_process_group: Optional[ProcessGroup] = None,
                  mp_process_group: Optional[ProcessGroup] = None) -> None:
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py
index b25e2e8a0..71e8030ac 100644
--- a/colossalai/zero/sharded_param/sharded_param.py
+++ b/colossalai/zero/sharded_param/sharded_param.py
@@ -8,11 +8,8 @@ from .tensorful_state import StatefulTensor, TensorState
 
 class ShardedParamV2(object):
 
-    def __init__(self,
-                 param: torch.nn.Parameter,
-                 process_group: Optional[dist.ProcessGroup] = None,
-                 rm_torch_payload=False) -> None:
-        self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
+    def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
+        self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
         self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         self.fp32_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         # This attribute must be initialized in ShardedModel
diff --git a/colossalai/zero/sharded_param/sharded_tensor.py b/colossalai/zero/sharded_param/sharded_tensor.py
index 59dc899ed..8e799c314 100644
--- a/colossalai/zero/sharded_param/sharded_tensor.py
+++ b/colossalai/zero/sharded_param/sharded_tensor.py
@@ -1,22 +1,20 @@
 import torch
-import torch.distributed as dist
-from typing import Optional
 from colossalai.zero.sharded_param.tensorful_state import StatefulTensor, TensorState
+from typing import Optional
 
 
 class ShardedTensor(StatefulTensor):
 
-    def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
+    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
         r"""
         A tensor sharded in multiple processes.
         Constructed from an existing torch.Tensor instance.
         """
-        super().__init__(tensor)
-        self.trans_state(TensorState.HOLD)
+        super().__init__(tensor, state)
+
+        # keep the shape, numel and dtype of the initial tensor.
         self._origin_shape = tensor.shape
         self._origin_numel = tensor.numel()
         self._origin_dtype = tensor.dtype
-        self._is_sharded = False
 
     @property
diff --git a/colossalai/zero/sharded_param/tensorful_state.py b/colossalai/zero/sharded_param/tensorful_state.py
index 3be01f1a6..5bde388be 100644
--- a/colossalai/zero/sharded_param/tensorful_state.py
+++ b/colossalai/zero/sharded_param/tensorful_state.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from logging import NullHandler
+from typing import Optional
 
 import torch
 
@@ -8,22 +8,22 @@ class TensorState(Enum):
     HOLD = 1
     HOLD_AFTER_FWD = 2
     HOLD_AFTER_BWD = 3
+    COMPUTE = 4
 
 
 class StatefulTensor(object):
-    """A Structure stores a Torch Tensor and labeled states. 
-
+    """A Structure stores a Torch Tensor and labeled states.
+
     Inspired from the paper:
     PatrickStar: Parallel Training of Pre-trained Models via Chunk-based Memory Management
     https://arxiv.org/abs/2108.05818
     """
 
-    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
+    def __init__(self, tensor: torch.Tensor, state: Optional[TensorState] = TensorState.HOLD) -> None:
         self._state = state
-        if state is not TensorState.FREE:
-            self._payload = tensor
-        else:
-            self._payload = None
+        self._payload = tensor
+        if self._state == TensorState.FREE:
+            assert self._payload is None, f"payload has to be None if state is {self._state}"
 
     def data_ptr(self):
         if self._payload is None:
diff --git a/tests/test_zero_data_parallel/test_shard_param.py b/tests/test_zero_data_parallel/test_shard_param.py
index 5740223d6..56c77fdee 100644
--- a/tests/test_zero_data_parallel/test_shard_param.py
+++ b/tests/test_zero_data_parallel/test_shard_param.py
@@ -48,7 +48,7 @@ def _run_shard_param_v2(rank, world_size, port):
 
     param = torch.nn.Parameter(torch.randn(2, 3))
     param_ref = deepcopy(param)
-    sparam = ShardedParamV2(param=param, process_group=None)
+    sparam = ShardedParamV2(param=param)
 
     allclose(sparam.sharded_data_tensor.payload, param_ref.data)
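
Note (not part of the diff): the changes above thread a tensor-state lifecycle through the ZeRO hooks.
ShardedModelV2 resets every sharded parameter to HOLD around forward(), ZeroHook moves it to COMPUTE in
pre_fwd_exec/pre_bwd_exec, and parks it in HOLD_AFTER_FWD/HOLD_AFTER_BWD in the matching post hooks.
The sketch below is a minimal, self-contained illustration of that lifecycle; MiniStatefulTensor is a
stand-in written for this note rather than the actual colossalai StatefulTensor, whose trans_state may
additionally validate which transitions are legal.

from enum import Enum

import torch


class TensorState(Enum):
    # Mirrors tensorful_state.py; COMPUTE is the state added by this diff.
    FREE = 0
    HOLD = 1
    HOLD_AFTER_FWD = 2
    HOLD_AFTER_BWD = 3
    COMPUTE = 4


class MiniStatefulTensor:
    """Stripped-down stand-in for StatefulTensor, for illustration only."""

    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
        self._payload = tensor
        self._state = state

    def trans_state(self, state: TensorState) -> None:
        # The real class may restrict transitions; this sketch only records the new state.
        self._state = state

    @property
    def state(self) -> TensorState:
        return self._state


if __name__ == "__main__":
    data = MiniStatefulTensor(torch.randn(4))

    data.trans_state(TensorState.HOLD)            # ShardedModelV2._pre_forward_operations
    data.trans_state(TensorState.COMPUTE)         # ZeroHook.pre_fwd_exec
    data.trans_state(TensorState.HOLD_AFTER_FWD)  # ZeroHook.post_fwd_exec
    data.trans_state(TensorState.COMPUTE)         # ZeroHook.pre_bwd_exec
    data.trans_state(TensorState.HOLD_AFTER_BWD)  # ZeroHook.post_bwd_exec
    print(data.state)                             # TensorState.HOLD_AFTER_BWD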