mirror of https://github.com/hpcaitech/ColossalAI
parent 105c5301c3
commit 05e33b2578
@@ -114,3 +114,20 @@ def colo_model_data_move_to_cpu(t: Union[ShardedTensor, torch.Tensor]) -> None:
     GLOBAL_MODEL_DATA_TRACER.delete_tensor(t_payload)
     t_payload.data = t_payload.data.cpu()
     GLOBAL_MODEL_DATA_TRACER.add_tensor(t_payload)
+
+
+def colo_model_tensor_clone(t: Union[ShardedTensor, torch.Tensor], target_device: torch.device) -> torch.Tensor:
+    """
+    Clone a model data tensor
+
+    Args:
+        t (Union[ShardedTensor, torch.Tensor]): a model data tensor
+        target_device (torch.device): the target device
+    Returns:
+        torch.Tensor: a cloned torch tensor
+    """
+    t_payload = t.payload if isinstance(t, ShardedTensor) else t
+
+    ret = t_payload.to(target_device)
+    GLOBAL_MODEL_DATA_TRACER.add_tensor(ret)
+    return ret
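Note on the new helper: unlike colo_model_data_move_to_cpu, which rebinds the payload's .data in place, colo_model_tensor_clone returns a copy on target_device and registers that copy with GLOBAL_MODEL_DATA_TRACER. A minimal usage sketch, assuming a CUDA device is available and the tracer has been set up; the tensor below is illustrative, not from the commit:

    import torch
    from colossalai.utils.memory_utils.utils import colo_model_tensor_clone

    t = torch.randn(1024, device='cuda')                       # illustrative model data tensor on GPU
    cpu_copy = colo_model_tensor_clone(t, torch.device('cpu'))
    assert t.is_cuda                                            # the original tensor is left untouched
    assert cpu_copy.device == torch.device('cpu')               # the copy lives on CPU and is tracked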
@@ -11,7 +11,7 @@ from colossalai.engine.ophooks import register_ophooks_recursively
 from colossalai.engine.ophooks.zero_hook import ZeroHook
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
-from colossalai.utils.memory_utils.utils import colo_model_data_move_to_cpu, colo_cuda_memory_capacity
+from colossalai.utils.memory_utils.utils import colo_model_data_move_to_cpu, colo_cuda_memory_capacity, colo_model_tensor_clone
 from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
@@ -198,16 +198,16 @@ class ShardedModelV2(nn.Module):
             # the shape `grad` is the same as unsharded param
             # So we can just use `view(-1)` to ensure grad is a flat tensor shard
             if self.reuse_fp16_shard:
-                grad = p.col_attr.sharded_data_tensor.payload
+                grad_payload = p.col_attr.sharded_data_tensor.payload
             else:
-                grad = cast_tensor_to_fp32(p.col_attr.fp16_grad)
+                grad_payload = cast_tensor_to_fp32(p.col_attr.fp16_grad)
             if p.col_attr.offload_grad:
-                colo_model_data_move_to_cpu(grad)
+                grad_payload = colo_model_tensor_clone(grad_payload, torch.device('cpu'))
             if p.col_attr.fp32_grad is not None:
                 assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
-                p.col_attr.fp32_grad.add_(grad.view_as(p.col_attr.fp32_grad))
-                grad = p.col_attr.fp32_grad
-            p.grad.data = grad
+                p.col_attr.fp32_grad.add_(grad_payload.view_as(p.col_attr.fp32_grad))
+                grad_payload = p.col_attr.fp32_grad
+            p.grad.data = grad_payload
             p.col_attr.fp16_grad = None
             p.col_attr.fp32_grad = None
 
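The substantive change in this hunk: with reuse_fp16_shard=True the gradient is the fp16 shard's payload itself, so the old in-place colo_model_data_move_to_cpu(grad) (which, per the first hunk, rebinds .data to a CPU tensor) would have pulled the parameter shard to CPU along with the grad. Cloning instead hands p.grad a CPU copy while the shard stays put. A minimal sketch of the difference, assuming a CUDA device; the variable names are illustrative:

    import torch

    shard = torch.randn(4, device='cuda')    # stands in for p.col_attr.sharded_data_tensor.payload
    grad = shard                             # with reuse_fp16_shard=True the grad aliases the shard

    # Old path (colo_model_data_move_to_cpu): in-place move, so the shard follows the grad to CPU.
    grad.data = grad.data.cpu()
    assert shard.device == torch.device('cpu')

    # New path (colo_model_tensor_clone): copy out, the shard keeps living on the GPU.
    shard = torch.randn(4, device='cuda')
    grad = shard.to(torch.device('cpu'))     # what the helper does internally, plus tracer bookkeeping
    assert shard.is_cuda and grad.device == torch.device('cpu')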
@@ -9,6 +9,7 @@ from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
+from colossalai.utils.memory_tracer.model_data_memtracer import GLOBAL_MODEL_DATA_TRACER
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from torch import Tensor
@@ -217,6 +218,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # We must set grad to None
         # Because we will judge whether local grad accumulation
        # is enabled by wheter grad is None
+        for group in self.param_groups:
+            for p in group['params']:
+                GLOBAL_MODEL_DATA_TRACER.delete_tensor(p.grad)
         self.optim.zero_grad(set_to_none=True)
 
     def sync_grad(self):
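For readability, the same zero_grad bookkeeping as a standalone sketch; the wrapper name and the explicit optim/param_groups/tracer parameters are hypothetical, while the real method uses self.optim, self.param_groups, and GLOBAL_MODEL_DATA_TRACER as shown above:

    def zero_grad_with_tracer(optim, param_groups, tracer) -> None:
        # Deregister every gradient from the model-data tracer before it is freed,
        # so the tracer's accounting does not keep counting dropped grads.
        for group in param_groups:
            for p in group['params']:
                tracer.delete_tensor(p.grad)
        # Grads must end up as None: downstream code decides whether local grad
        # accumulation is enabled by checking whether grad is None.
        optim.zero_grad(set_to_none=True)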