[zero] optimize grad offload (#539)

* optimize grad offload

* polish code

* polish code

Branch: pull/545/head^2
Author: ver217, 2022-03-29 12:48:00 +08:00, committed by GitHub
Parent: 7d81b5b46e
Commit: fb841dd5c5
2 changed files with 13 additions and 11 deletions
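
Net effect of the two diffs below: ShardedModelV2 now offloads a reduced gradient by moving its payload to the CPU in place (colo_model_data_move_to_cpu) rather than cloning it there (colo_model_tensor_clone), so only one live copy of the gradient exists at a time, while ShardedOptimizerV2 now rebuilds the fp16 shard payload by cloning p.half() onto the current CUDA device rather than calling colo_model_data_tensor_move. As orientation, here is a plain-PyTorch sketch of what the two memory_utils helpers approximately do; the function bodies are assumptions for illustration, not the Colossal-AI implementations, which presumably also do bookkeeping such as updating GLOBAL_MODEL_DATA_TRACER.

import torch

# Illustrative approximations only: the real helpers live in
# colossalai.utils.memory_utils.utils and presumably do extra bookkeeping
# (e.g. memory tracing) on top of the plain data movement shown here.

def move_to_cpu_inplace(t: torch.Tensor) -> None:
    # Roughly colo_model_data_move_to_cpu: rebind the tensor's storage to a
    # CPU copy, so the tensor object survives but its CUDA memory can be freed.
    if t.device.type == 'cuda':
        t.data = t.data.cpu()

def clone_to_device(t: torch.Tensor, target_device: torch.device) -> torch.Tensor:
    # Roughly colo_model_tensor_clone: return a copy of t on target_device,
    # leaving the original tensor where it was.
    return t.to(target_device, copy=True)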


@@ -11,9 +11,10 @@ from colossalai.engine.ophooks import register_ophooks_recursively
 from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
-from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory_utils.utils import colo_model_data_move_to_cpu, colo_cuda_memory_capacity, colo_model_tensor_clone
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
+from colossalai.utils.memory_tracer.model_data_memtracer import \
+    GLOBAL_MODEL_DATA_TRACER
+from colossalai.utils.memory_utils.utils import (colo_cuda_memory_capacity, colo_model_data_move_to_cpu)
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from torch.distributed import ProcessGroup
@@ -206,7 +207,7 @@ class ShardedModelV2(nn.Module):
             else:
                 grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad)
             if p.col_attr.offload_grad:
-                grad_payload = colo_model_tensor_clone(grad_payload, torch.device('cpu'))
+                colo_model_data_move_to_cpu(grad_payload)
             if p.col_attr.fp32_grad is not None:
                 assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
                 p.col_attr.fp32_grad.add_(grad_payload.view_as(p.col_attr.fp32_grad))
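
The point of this hunk, presumably, is that the old colo_model_tensor_clone call kept both the CUDA gradient and its new CPU clone alive until the CUDA tensor went out of scope, whereas the in-place move releases the CUDA copy right away. The following CPU-only toy (with made-up stand-in tensors fp16_grad and fp32_grad) mirrors the flow of the surrounding lines and shows that the later accumulation step still works on the moved payload:

import torch

# Stand-in tensors; in the real code these come from p.col_attr.
fp16_grad = torch.randn(4, 4, dtype=torch.float16)   # reduced fp16 gradient shard
fp32_grad = torch.zeros(16, dtype=torch.float32)     # existing fp32 accumulation buffer

grad_payload = fp16_grad.float()                      # plays the role of cast_tensor_to_fp32
grad_payload.data = grad_payload.data.cpu()           # in-place move, as colo_model_data_move_to_cpu would do
fp32_grad.add_(grad_payload.view_as(fp32_grad))       # same accumulation step as in the diff

print(fp32_grad.device, fp32_grad.dtype)              # cpu torch.float32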


@@ -10,14 +10,14 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.utils.memory_utils.utils import (colo_model_tensor_clone, colo_tensor_mem_usage)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
+from colossalai.zero.sharded_optim._utils import has_inf_or_nan
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
-from colossalai.zero.sharded_optim._utils import has_inf_or_nan
-from colossalai.utils.memory_utils.utils import colo_model_data_tensor_move, colo_tensor_mem_usage
 
 
 class OptimState(Enum):
@@ -204,7 +204,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 # Since p.data is fp32 and p.col_attr.sharded_data_tensor is fp16
                 # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                colo_model_data_tensor_move(p, p.col_attr.sharded_data_tensor)
+                p.col_attr.sharded_data_tensor.reset_payload(
+                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
             if not is_param_sharded:
                 # We gather full fp16 param here
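
The optimizer-side hunk goes the other way: instead of moving p's data into the sharded tensor with colo_model_data_tensor_move, the updated fp32 master parameter is cast to fp16 and cloned onto the current CUDA device, and that clone becomes the new payload of p.col_attr.sharded_data_tensor (the CPU fp32 to GPU fp16 copy that the TODO comment above wants to optimize later). Below is a minimal self-contained sketch of that write-back step; ShardedTensorStub and write_back_fp16_shard are hypothetical names, and the sketch falls back to CPU so it runs without a GPU.

import torch

class ShardedTensorStub:
    # Hypothetical stand-in for p.col_attr.sharded_data_tensor: it just owns a
    # payload tensor that reset_payload() swaps out.
    def __init__(self, payload: torch.Tensor):
        self.payload = payload

    def reset_payload(self, new_payload: torch.Tensor) -> None:
        self.payload = new_payload

def write_back_fp16_shard(master_param: torch.Tensor, shard: ShardedTensorStub) -> None:
    # Mirror of the added lines: cast the fp32 master parameter to fp16, clone it
    # onto the target device, and make the clone the shard's payload.
    if torch.cuda.is_available():
        device = torch.device('cuda', torch.cuda.current_device())
    else:
        device = torch.device('cpu')  # keeps the sketch runnable without a GPU
    shard.reset_payload(master_param.half().to(device, copy=True))

# Usage: fp32 master weights (often kept on CPU) become a fresh fp16 payload.
master = torch.randn(8, dtype=torch.float32)
shard = ShardedTensorStub(torch.empty(8, dtype=torch.float16))
write_back_fp16_shard(master, shard)
print(shard.payload.dtype)  # torch.float16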