@@ -1,5 +1,5 @@
-from ast import Try
 import functools
+from ast import Try
 from collections import OrderedDict
 from typing import Any, Optional
 
@@ -12,16 +12,17 @@ from colossalai.engine.ophooks import register_ophooks_recursively
 from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
+from colossalai.utils.commons.memory import col_cuda_memory_capacity
+from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
+from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.zero.sharded_param import ShardedParamV2
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 
-from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
-from colossalai.utils.memory_tracer.allocator import col_move_to_cpu
 
 from ._zero3_utils import (cast_float_arguments, cast_tensor_to_fp16, cast_tensor_to_fp32, chunk_and_pad,
                            get_gradient_predivide_factor)
 from colossalai.utils.commons.memory import col_cuda_memory_capacity
 
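Judging by their use later in this diff, the new imports wire the sharded model into the memory tracer: `MemStatsCollector` samples memory usage, `col_cuda_memory_capacity` reports the device budget, and `col_move_to_cpu` evicts a tensor's payload to host memory. As a rough illustration of what an offload helper of this kind does, here is a minimal sketch; `move_tensor_to_cpu` is a hypothetical stand-in, not ColossalAI's implementation:

```python
import torch

def move_tensor_to_cpu(t: torch.Tensor) -> None:
    # Hypothetical stand-in for an offload helper like col_move_to_cpu.
    if not t.is_cuda:
        return
    # Stage the payload in pinned host memory so the copy can run
    # asynchronously and the CUDA block can be reclaimed afterwards.
    cpu_data = torch.empty(t.shape, dtype=t.dtype, pin_memory=True)
    cpu_data.copy_(t, non_blocking=True)
    # Rebind the tensor's storage in place; callers keep the same object.
    t.data = cpu_data
```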
@@ -164,8 +165,15 @@ class ShardedModelV2(nn.Module):
             # If world size == 1 and sharded param,
             # the shape `grad` is the same as unsharded param
             # So we can just use `view(-1)` to ensure grad is a flat tensor shard
-            p.grad.data = p.col_attr.grad.view(-1)
-            p.col_attr.grad = None
+            grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
+            if self._cpu_offload:
+                col_move_to_cpu(grad)
+            if p.col_attr.fp32_grad is not None:
+                p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))
+                grad = p.col_attr.fp32_grad
+            p.grad.data = grad.view(-1)
+            p.col_attr.fp16_grad = None
+            p.col_attr.fp32_grad = None
 
     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
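The replacement body above splits gradient handling into a transient half-precision slot (`fp16_grad`) and a persistent fp32 accumulation buffer (`fp32_grad`): each reduced fp16 gradient is upcast, optionally offloaded, folded into the fp32 buffer, and finally flattened into `p.grad`. A minimal self-contained sketch of that accumulate-then-flatten flow, using illustrative names rather than the real `ShardedParamV2` attributes:

```python
import torch

class GradSlots:
    # Illustrative per-parameter slots mirroring fp16_grad / fp32_grad.
    def __init__(self):
        self.fp16_grad = None    # latest reduced gradient, half precision
        self.fp32_grad = None    # running fp32 accumulation buffer

def finalize_grad(slots: GradSlots, cpu_offload: bool = False) -> torch.Tensor:
    # Upcast before accumulating so repeated additions keep full precision.
    grad = slots.fp16_grad.to(torch.float32)
    if cpu_offload:
        grad = grad.cpu()
    if slots.fp32_grad is not None:
        # Fold this step's gradient into the persistent buffer
        # (gradient accumulation across micro-batches).
        slots.fp32_grad.add_(grad.view_as(slots.fp32_grad))
        grad = slots.fp32_grad
    slots.fp16_grad = None
    slots.fp32_grad = None
    return grad.view(-1)    # flat fp32 tensor, ready to install as the grad

slots = GradSlots()
slots.fp16_grad = torch.randn(4, 4).half()
flat = finalize_grad(slots)    # shape (16,), dtype torch.float32
```

The hook in the next hunk then only has to park the reduced gradient in `fp16_grad`; the conversion and accumulation happen once, after backward finishes.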
@@ -216,23 +224,7 @@ class ShardedModelV2(nn.Module):
         # Average grad by world_size for consistency with PyTorch DDP.
         reduced_grad.data.div_(self.gradient_postdivide_factor)
-
-        # Make sure we store fp32 grad
-        reduced_grad.data = cast_tensor_to_fp32(reduced_grad.data)
-
-        # Maybe offload
-        # TODO() optimize GPU->CPU bandwidth utilization
-        if self._cpu_offload:
-            col_move_to_cpu(reduced_grad)
-            # reduced_grad.data = reduced_grad.data.cpu()
-
-        if param.col_attr.grad is None:
-            param.col_attr.grad = reduced_grad.data
-        else:
-            # When dp size = 1
-            # param.col_attr.grad is local accumulated grad shard (full but flatten)
-            # But reduced_grad here is full grad
-            # We should call `view_as`
-            param.col_attr.grad.add_(reduced_grad.data.view_as(param.col_attr.grad))
+
+        param.col_attr.fp16_grad = reduced_grad.data
 
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.col_attr.data for p in self.module.parameters()])
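The retained context line divides the reduced gradient by `gradient_postdivide_factor` to match DDP's averaging over `world_size`. The point of splitting the division (see `get_gradient_predivide_factor` in the imports) is numerical range in fp16: dividing by roughly the square root of `world_size` before the reduce-scatter and by the remainder afterwards avoids overflow during the reduction and underflow after it, while the two factors still multiply to `world_size`. A sketch of how such a factor is commonly computed in FSDP-style codebases; treat the exact rule as an assumption about ColossalAI's helper:

```python
def gradient_predivide_factor(world_size: int) -> float:
    # Grow the factor by powers of two while it divides world_size and
    # stays below world_size / factor, landing near sqrt(world_size).
    factor = 1
    while world_size % factor == 0 and world_size / factor > factor:
        factor *= 2
    return float(factor)

world_size = 64
pre = gradient_predivide_factor(world_size)    # 8.0
post = world_size / pre                        # 8.0, and pre * post == 64
```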