[zero] optimize grad offload (#539)

* optimize grad offload

* polish code

* polish code
ver217 2022-03-29 12:48:00 +08:00 committed by GitHub
parent 7d81b5b46e
commit fb841dd5c5
2 changed files with 13 additions and 11 deletions

View File

@@ -11,9 +11,10 @@ from colossalai.engine.ophooks import register_ophooks_recursively
 from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory_utils.utils import colo_model_data_move_to_cpu, colo_cuda_memory_capacity, colo_model_tensor_clone
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.utils.memory_tracer.model_data_memtracer import \
+    GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_utils.utils import (colo_cuda_memory_capacity, colo_model_data_move_to_cpu)
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from torch.distributed import ProcessGroup
@@ -206,7 +207,7 @@ class ShardedModelV2(nn.Module):
         else:
             grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad)
         if p.col_attr.offload_grad:
-            grad_payload = colo_model_tensor_clone(grad_payload, torch.device('cpu'))
+            colo_model_data_move_to_cpu(grad_payload)
         if p.col_attr.fp32_grad is not None:
             assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
             p.col_attr.fp32_grad.add_(grad_payload.view_as(p.col_attr.fp32_grad))
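Note on the hunk above: when offload_grad is set, the old path cloned the reduced gradient onto the CPU and rebound only the local name, while the new path moves the tensor's storage in place, so every existing reference sees the CPU data and the CUDA block can be released right away. The internals of colo_model_tensor_clone and colo_model_data_move_to_cpu are not part of this diff; the snippet below is a minimal plain-PyTorch sketch of the two behaviours under that assumption, and the function names in it (clone_to_cpu, move_to_cpu_inplace) are illustrative, not ColossalAI APIs.

import torch


def clone_to_cpu(t: torch.Tensor) -> torch.Tensor:
    # Rough stand-in for the old path: allocate a fresh CPU tensor and copy the
    # payload over; the CUDA source survives as long as anything still holds it.
    return t.detach().clone().to('cpu')


def move_to_cpu_inplace(t: torch.Tensor) -> None:
    # Rough stand-in for the new path: rebind the tensor's storage to a CPU
    # copy, so existing references now point at CPU data and the CUDA memory
    # can be freed immediately.
    if t.device.type == 'cuda':
        t.data = t.data.cpu()


if __name__ == '__main__' and torch.cuda.is_available():
    grad = torch.randn(1024, 1024, device='cuda')
    move_to_cpu_inplace(grad)
    print(grad.device)  # cpu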

View File

@@ -10,14 +10,14 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.utils.memory_utils.utils import (colo_model_tensor_clone, colo_tensor_mem_usage)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
+from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
-from colossalai.zero.sharded_optim._utils import has_inf_or_nan
-from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_tensor_mem_usage


 class OptimState(Enum):
@@ -204,7 +204,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 # Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16
                 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                colo_model_data_tensor_move(p, p.col_attr.sharded_data_tensor)
+                p.col_attr.sharded_data_tensor.reset_payload(
+                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
                 if not is_param_sharded:
                     # We gather full fp16 param here
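Note on the hunk above: after the optimizer step, the fp32 master copy of each parameter (kept on CPU) has to be pushed back into the fp16 shard on the GPU. The new line casts the parameter to fp16, clones it onto the current CUDA device, and installs that tensor as the shard's payload via reset_payload; the retained TODO suggests the fp16 CPU intermediate produced by p.half() could still be avoided. The sketch below mirrors that pattern under stated assumptions: ShardedTensorStub and push_master_param_to_fp16_shard are hypothetical stand-ins, and only the reset_payload name comes from the diff.

import torch


class ShardedTensorStub:
    # Hypothetical stand-in for p.col_attr.sharded_data_tensor: it simply owns
    # a payload tensor that can be swapped out wholesale.
    def __init__(self, payload: torch.Tensor):
        self.payload = payload

    def reset_payload(self, tensor: torch.Tensor) -> None:
        self.payload = tensor


def push_master_param_to_fp16_shard(p: torch.nn.Parameter, shard: ShardedTensorStub) -> None:
    # CPU fp32 master -> fp16 copy -> materialised on the current CUDA device,
    # then installed as the shard's payload. The fp16 intermediate created by
    # half() before the device transfer is the cost the TODO points at.
    device = torch.device('cuda', torch.cuda.current_device())
    shard.reset_payload(p.detach().half().to(device))


if __name__ == '__main__' and torch.cuda.is_available():
    p = torch.nn.Parameter(torch.randn(8, dtype=torch.float32))  # CPU fp32 master copy
    shard = ShardedTensorStub(torch.empty(8, dtype=torch.float16, device='cuda'))
    push_master_param_to_fp16_shard(p, shard)
    print(shard.payload.dtype, shard.payload.device)  # torch.float16 cuda:0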