diff --git a/colossalai/engine/ophooks/zero_hook.py b/colossalai/engine/ophooks/zero_hook.py
index 0ef826536..a4df0f502 100644
--- a/colossalai/engine/ophooks/zero_hook.py
+++ b/colossalai/engine/ophooks/zero_hook.py
@@ -1,12 +1,14 @@
+from typing import Optional
+
 import torch
 from colossalai.registry import OPHOOKS
 from colossalai.utils import get_current_device
+from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.utils.memory_tracer.model_data_memtracer import \
+    GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.shard_utils import BaseShardStrategy
 
 from ._base_ophook import BaseOpHook
-from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from typing import Optional
 
 
 @OPHOOKS.register_module
@@ -62,8 +64,8 @@ class ZeroHook(BaseOpHook):
             if param.grad is not None:
                 if param.col_attr.bwd_count == 0:
                     # We haven't stored local accumulated grad yet
-                    assert param.col_attr.grad is None
-                    param.col_attr.grad = param.grad.data
+                    assert param.col_attr.fp32_grad is None
+                    param.col_attr.fp32_grad = param.grad.data
                     param.grad = None
                 else:
                     # We have stored local accumulated grad
diff --git a/colossalai/zero/init_ctx/init_context.py b/colossalai/zero/init_ctx/init_context.py
index 513a5e3f8..242208eca 100644
--- a/colossalai/zero/init_ctx/init_context.py
+++ b/colossalai/zero/init_ctx/init_context.py
@@ -1,9 +1,10 @@
 import functools
 
 import torch
+from colossalai.utils.memory_tracer.model_data_memtracer import \
+    GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param import ShardedParamV2
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 
 
 # Inserts _post_init_method at the end of init method
@@ -154,6 +155,6 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
         if self.shard_param:
             self.shard_strategy.shard(tensor_list=[param.col_attr._data_sharded_tensor])
             GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._data_sharded_tensor.payload)
-        if param.col_attr.grad and self.shard_grad:
-            self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
-            GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
+        # if param.col_attr.grad and self.shard_grad:
+        #     self.shard_strategy.shard(tensor_list=[param.col_attr._grad_sharded_tensor])
+        #     GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr._grad_sharded_tensor.payload)
diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py
index f92107e6c..2de39de4a 100644
--- a/colossalai/zero/sharded_model/sharded_model_v2.py
+++ b/colossalai/zero/sharded_model/sharded_model_v2.py
@@ -1,5 +1,5 @@
-from ast import Try
 import functools
+from ast import Try
 from collections import OrderedDict
 from typing import Any, Optional
 
@@ -12,16 +12,17 @@ from colossalai.engine.ophooks import register_ophooks_recursively
 from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
+from colossalai.utils.commons.memory import col_cuda_memory_capacity
+from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
+from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
-from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
+
 from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
                            get_gradient_predivide_factor)
-from colossalai.utils.commons.memory import col_cuda_memory_capacity
 
 
 class ShardedModelV2(nn.Module):
@@ -164,8 +165,15 @@ class ShardedModelV2(nn.Module):
             # If world size == 1 and sharded param,
             # the shape `grad` is the same as unsharded param
             # So we can just use `view(-1)` to ensure grad is a flat tensor shard
-            p.grad.data = p.col_attr.grad.view(-1)
-            p.col_attr.grad = None
+            grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
+            if self._cpu_offload:
+                col_move_to_cpu(grad)
+            if p.col_attr.fp32_grad is not None:
+                p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))
+                grad = p.col_attr.fp32_grad
+            p.grad.data = grad.view(-1)
+            p.col_attr.fp16_grad = None
+            p.col_attr.fp32_grad = None
 
     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
@@ -216,23 +224,7 @@ class ShardedModelV2(nn.Module):
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
 
-        # Make sure we store fp32 grad
-        reduced_grad.data = cast_tensor_to_fp32(reduced_grad.data)
-
-        # Maybe offload
-        # TODO() optimize GPU->CPU bandwidth utilization
-        if self._cpu_offload:
-            col_move_to_cpu(reduced_grad)
-            # reduced_grad.data = reduced_grad.data.cpu()
-
-        if param.col_attr.grad is None:
-            param.col_attr.grad = reduced_grad.data
-        else:
-            # When dp size = 1
-            # param.col_attr.grad is local accumulated grad shard (full but flatten)
-            # But reduced_grad here is full grad
-            # We should call `view_as`
-            param.col_attr.grad.add_(reduced_grad.data.view_as(param.col_attr.grad))
+        param.col_attr.fp16_grad = reduced_grad.data
 
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()])
diff --git a/colossalai/zero/sharded_param/sharded_param.py b/colossalai/zero/sharded_param/sharded_param.py
index 3c90fda64..01987c1e2 100644
--- a/colossalai/zero/sharded_param/sharded_param.py
+++ b/colossalai/zero/sharded_param/sharded_param.py
@@ -16,7 +16,8 @@ class ShardedParamV2(object):
                  process_group: Optional[dist.ProcessGroup] = None,
                  rm_torch_payload=False) -> None:
         self._data_sharded_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
-        self._grad_sharded_tensor: Optional[torch.Tensor] = None
+        self.fp16_grad: Optional[torch.Tensor] = None
+        self.fp32_grad: Optional[torch.Tensor] = None
 
         # make sure the shared param is the only owner of payload
         # The param.data maybe used to init the other part of the model.
@@ -39,14 +40,6 @@ class ShardedParamV2(object):
     def data(self):
         return self._data_sharded_tensor
 
-    @property
-    def grad(self):
-        return self._grad_sharded_tensor
-
-    @grad.setter
-    def grad(self, t: torch.Tensor):
-        self._grad_sharded_tensor = t
-
     @property
     def param_is_sharded(self):
         return self._data_sharded_tensor.is_sharded
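
Note on the gradient flow introduced by this patch, as a minimal, self-contained sketch. ColAttr, on_reduce_scatter, and finalize_grad below are illustrative stand-ins, not ColossalAI APIs: the real code keeps these fields on param.col_attr, and torch casts stand in for cast_tensor_to_fp32 / col_move_to_cpu. The reduce-scatter callback now only stashes the fp16 result in fp16_grad; the end-of-backward step casts it to fp32, optionally offloads it to CPU, folds in any locally accumulated fp32_grad, and produces the flat tensor that ShardedModelV2 writes back to p.grad.

from typing import Optional

import torch


class ColAttr:
    """Stand-in for the per-parameter attribute object touched by this diff."""

    def __init__(self) -> None:
        self.fp16_grad: Optional[torch.Tensor] = None  # reduce-scattered grad of the current backward pass
        self.fp32_grad: Optional[torch.Tensor] = None  # locally accumulated fp32 grad (gradient accumulation)


def on_reduce_scatter(col_attr: ColAttr, reduced_grad: torch.Tensor) -> None:
    # The reduce-scatter callback now only stashes the fp16 result;
    # casting and offloading are deferred to the end of backward.
    col_attr.fp16_grad = reduced_grad


def finalize_grad(col_attr: ColAttr, cpu_offload: bool = False) -> torch.Tensor:
    # Cast to fp32, optionally offload, fold in any accumulated fp32 grad,
    # then return a flat tensor, mirroring the new post-backward logic.
    grad = col_attr.fp16_grad.to(torch.float32)
    if cpu_offload:
        grad = grad.cpu()
    if col_attr.fp32_grad is not None:
        col_attr.fp32_grad.add_(grad.view_as(col_attr.fp32_grad))
        grad = col_attr.fp32_grad
    col_attr.fp16_grad = None
    col_attr.fp32_grad = None
    return grad.view(-1)


if __name__ == "__main__":
    attr = ColAttr()
    on_reduce_scatter(attr, torch.ones(4, dtype=torch.float16))
    print(finalize_grad(attr))  # tensor([1., 1., 1., 1.])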