mirror of https://github.com/hpcaitech/ColossalAI
parent
105c5301c3
commit
05e33b2578
|
@ -114,3 +114,20 @@ def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
|
|||
GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
|
||||
t_payload.data = t_payload.data.cpu()
|
||||
GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
|
||||
|
||||
|
||||
def colo_model_tensor_clone(t: Union[ShardedTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
|
||||
"""
|
||||
Clone a model data tensor
|
||||
|
||||
Args:
|
||||
t (Union[ShardedTensor, torch.Tensor]): a model data tensor
|
||||
target_device (torch.device): the target device
|
||||
Returns:
|
||||
torch.Tensor: a cloned torch tensor
|
||||
"""
|
||||
t_payload = t.payload if isinstance(t, ShardedTensor) else t
|
||||
|
||||
ret = t_payload.to(target_device)
|
||||
GLOBAL_MODEL_DATA_TRACER.add_tensor(ret)
|
||||
return ret
|
||||
|
|
|
@ -11,7 +11,7 @@ from colossalai.engine.ophooks import register_ophooks_recursively
|
|||
from colossalai.engine.ophooks.zero_hook import ZeroHook
|
||||
from colossalai.engine.paramhooks import BaseParamHookMgr
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.utils.memory_utils.utils import colo_model_data_move_to_cpu, colo_cuda_memory_capacity
|
||||
from colossalai.utils.memory_utils.utils import colo_model_data_move_to_cpu, colo_cuda_memory_capacity, colo_model_tensor_clone
|
||||
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
|
||||
|
@ -198,16 +198,16 @@ class ShardedModelV2(nn.Module):
|
|||
# the shape `grad` is the same as unsharded param
|
||||
# So we can just use `view(-1)` to ensure grad is a flat tensor shard
|
||||
if self.reuse_fp16_shard:
|
||||
grad = p.col_attr.sharded_data_tensor.payload
|
||||
grad_payload = p.col_attr.sharded_data_tensor.payload
|
||||
else:
|
||||
grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
|
||||
grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad)
|
||||
if p.col_attr.offload_grad:
|
||||
colo_model_data_move_to_cpu(grad)
|
||||
grad_payload = colo_model_tensor_clone(grad_payload, torch.device('cpu'))
|
||||
if p.col_attr.fp32_grad is not None:
|
||||
assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
|
||||
p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))
|
||||
grad = p.col_attr.fp32_grad
|
||||
p.grad.data = grad
|
||||
p.col_attr.fp32_grad.add_(grad_payload.view_as(p.col_attr.fp32_grad))
|
||||
grad_payload = p.col_attr.fp32_grad
|
||||
p.grad.data = grad_payload
|
||||
p.col_attr.fp16_grad = None
|
||||
p.col_attr.fp32_grad = None
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ from colossalai.context.parallel_mode import ParallelMode
|
|||
from colossalai.core import global_context as gpc
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
|
||||
from torch import Tensor
|
||||
|
@ -217,6 +218,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
# We must set grad to None
|
||||
# Because we will judge whether local grad accumulation
|
||||
# is enabled by wheter grad is None
|
||||
for group in self.param_groups:
|
||||
for p in group['params']:
|
||||
GLOBAL_MODEL_DATA_TRACER.delete_tensor(p.grad)
|
||||
self.optim.zero_grad(set_to_none=True)
|
||||
|
||||
def sync_grad(self):
|
||||
|
|
Loading…
Reference in New Issue