mirror of https://github.com/hpcaitech/ColossalAI
[zero] trace states of fp16/32 grad and fp32 param (#571)
parent
7675366fce
commit
7c6c427db1
|
@ -51,9 +51,9 @@ def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_
|
|||
"""
|
||||
A colossal API for model data tensor move.
|
||||
The src and target tensors could be resident on both CPU and GPU.
|
||||
|
||||
|
||||
NOTE() The source tensor payload will be removed after this function.
|
||||
|
||||
|
||||
The function will record the communication volume between CPU and GPU.
|
||||
Args:
|
||||
t_src (Union[StatefulTensor, torch.Tensor]): source tensor
|
||||
|
@ -93,7 +93,7 @@ def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], t
|
|||
raise TypeError('colo_model_data_move_to_cpu dose not accept type {type(t)}')
|
||||
|
||||
if isinstance(target_device, int):
|
||||
target_device = torch.cuda(f'device"{target_device}')
|
||||
target_device = torch.device(f'cuda:{target_device}')
|
||||
|
||||
# deal with torch.device('cpu') and torch.device('cpu:0)
|
||||
if t_payload.device.type == target_device.type:
|
||||
|
|
|
@ -18,7 +18,7 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
|
|||
from colossalai.utils.memory_utils.utils import (colo_cuda_memory_capacity, colo_model_data_move_to_cpu)
|
||||
from colossalai.zero.shard_utils import BaseShardStrategy
|
||||
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
|
||||
from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
|
||||
from colossalai.zero.sharded_param.tensorful_state import TensorState
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
|
@ -245,27 +245,7 @@ class ShardedModelV2(nn.Module):
|
|||
# We also allows to interleave no-sync pass with sync passes, if desired.
|
||||
if not self._require_backward_grad_sync:
|
||||
continue
|
||||
# Reduced grad is saved in `p.colo_attr.saved_grad`
|
||||
# It can be on CPU or CUDA
|
||||
# It can be fp16 or fp32
|
||||
# We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`.
|
||||
if self.reuse_fp16_shard:
|
||||
grad_fp16_payload = p.colo_attr.sharded_data_tensor.payload
|
||||
else:
|
||||
grad_fp16_payload = cast_tensor_to_fp32(p.colo_attr.fp16_grad.payload)
|
||||
assert isinstance(grad_fp16_payload, torch.Tensor)
|
||||
if p.colo_attr.offload_grad:
|
||||
colo_model_data_move_to_cpu(grad_fp16_payload)
|
||||
if not p.colo_attr.saved_grad.is_null():
|
||||
assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
|
||||
# Accumulate grad, saved grad must be fp32
|
||||
p.colo_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.colo_attr.saved_grad.payload))
|
||||
p.colo_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.colo_attr.saved_grad.payload))
|
||||
else:
|
||||
p.colo_attr.saved_grad.reset_payload(grad_fp16_payload)
|
||||
|
||||
p.grad = None
|
||||
p.colo_attr.fp16_grad.set_null()
|
||||
|
||||
@torch.no_grad()
|
||||
def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
|
@ -322,11 +302,22 @@ class ShardedModelV2(nn.Module):
|
|||
if self.gradient_postdivide_factor > 1:
|
||||
# Average grad by world_size for consistency with PyTorch DDP.
|
||||
reduced_grad.data.div_(self.gradient_postdivide_factor)
|
||||
# FIXME(ver217): remove the below line when impl eviction policy
|
||||
if param.colo_attr.offload_grad:
|
||||
colo_model_data_move_to_cpu(reduced_grad)
|
||||
if self.reuse_fp16_shard:
|
||||
param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
|
||||
assert param.colo_attr.saved_grad.is_null(
|
||||
), 'Gradien accumulation is not supported when reuse_fp16_shard=True'
|
||||
param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad)
|
||||
param.colo_attr.sharded_data_tensor.is_sharded = True
|
||||
param.colo_attr.saved_grad.reset_payload(param.colo_attr.sharded_data_tensor.payload)
|
||||
else:
|
||||
param.colo_attr.fp16_grad = StatefulTensor(reduced_grad.data)
|
||||
reduced_grad = cast_tensor_to_fp32(reduced_grad)
|
||||
if param.colo_attr.saved_grad.is_null():
|
||||
param.colo_attr.saved_grad.reset_payload(reduced_grad)
|
||||
else:
|
||||
param.colo_attr.saved_grad.payload.add_(reduced_grad.view_as(param.colo_attr.saved_grad.payload))
|
||||
param.colo_attr.saved_grad.trans_state(TensorState.HOLD)
|
||||
|
||||
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
|
||||
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
|
||||
|
|
|
@ -12,11 +12,12 @@ from colossalai.logging import get_dist_logger
|
|||
from colossalai.nn.optimizer import ColossalaiOptimizer
|
||||
from colossalai.utils.memory_tracer.model_data_memtracer import \
|
||||
GLOBAL_MODEL_DATA_TRACER
|
||||
from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move, colo_model_tensor_clone,
|
||||
from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
|
||||
colo_tensor_mem_usage)
|
||||
from colossalai.zero.sharded_model import ShardedModelV2
|
||||
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
|
||||
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
|
||||
from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
|
||||
from torch import Tensor
|
||||
from torch.distributed import ProcessGroup
|
||||
from torch.nn.parameter import Parameter
|
||||
|
@ -112,7 +113,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
self._logger = get_dist_logger("ShardedOptimizerV2")
|
||||
|
||||
# Store fp32 param shards
|
||||
self.master_params: Dict[Parameter, Tensor] = {}
|
||||
self.master_params: Dict[Parameter, StatefulTensor] = {}
|
||||
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
|
@ -123,7 +124,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
# Param is no sharded, which means we use ZeRO-2 here
|
||||
# As we only store param shard, we shard it here
|
||||
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
|
||||
self.master_params[p] = cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device)
|
||||
self.master_params[p] = StatefulTensor(
|
||||
cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
|
||||
if not is_param_sharded:
|
||||
# In this branch, there's no need to shard param
|
||||
# So we gather here
|
||||
|
@ -184,13 +186,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
self.zero_grad()
|
||||
return
|
||||
|
||||
# assign master param pointers to p.data.
|
||||
# We will not trigger data copy here.
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
p.data = self.master_params[p]
|
||||
# Now p.data is sharded
|
||||
# So optimizer states are sharded naturally
|
||||
self._prepare_data()
|
||||
|
||||
self._logger.debug(
|
||||
f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
|
||||
|
@ -201,30 +197,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
self._logger.debug(
|
||||
f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
|
||||
ranks=[0])
|
||||
# Copy master param data (fp32) to payload of colo_attr (fp16)
|
||||
# TODO() improve efficiency by gathering tensors into a chunk and transfering
|
||||
# a chunk.
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
|
||||
if not is_param_sharded:
|
||||
# We use ZeRO-2 here
|
||||
# The `p.colo_attr.sharded_data_tensor` saves full fp16 param
|
||||
# But we only have updated fp32 param shard here
|
||||
# So we first shard full fp16 param and copy fp32 param shard to it
|
||||
# Then we will gather them
|
||||
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
|
||||
# We have to use `copy_payload` instead of `reset_payload`
|
||||
# Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
|
||||
|
||||
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
|
||||
p.colo_attr.sharded_data_tensor.reset_payload(
|
||||
colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
|
||||
|
||||
if not is_param_sharded:
|
||||
# We gather full fp16 param here
|
||||
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
|
||||
p.data = p.colo_attr.sharded_data_tensor.payload
|
||||
self._write_back_data()
|
||||
return ret
|
||||
|
||||
def backward(self, loss: Tensor) -> None:
|
||||
|
@ -276,6 +249,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
# Because we will judge whether local grad accumulation
|
||||
# is enabled by wheter grad is None
|
||||
self.optim.zero_grad(set_to_none=True)
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
p.colo_attr.saved_grad.set_null()
|
||||
|
||||
def sync_grad(self):
|
||||
pass
|
||||
|
@ -288,9 +264,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
fp32_shards_used_cuda_margin_mem = 0
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
|
||||
shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size()
|
||||
if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
|
||||
self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
|
||||
colo_model_data_tensor_move_inline(self.master_params[p], torch.cuda.current_device())
|
||||
p.grad.data = p.grad.data.to(torch.cuda.current_device())
|
||||
p.colo_attr.offload_grad = False
|
||||
fp32_shards_used_cuda_margin_mem += shard_mem
|
||||
|
@ -298,10 +274,50 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
|
|||
def _prepare_grads(self):
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
p.colo_attr.saved_grad.trans_state(TensorState.COMPUTE)
|
||||
# FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful infomation
|
||||
# If we change p.grad directly
|
||||
# it may raise error because of different shape/dtype/device of p.data and p.grad
|
||||
# We just set p.data = p.colo_attr.saved_grad.payload here
|
||||
p.data = p.colo_attr.saved_grad.payload
|
||||
p.grad = p.colo_attr.saved_grad.payload
|
||||
# Set p.data to empty tensor, in case of memory leaking
|
||||
p.colo_attr.remove_torch_payload()
|
||||
|
||||
def _prepare_data(self):
|
||||
# assign master param pointers to p.data.
|
||||
# We will not trigger data copy here.
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
self.master_params[p].trans_state(TensorState.COMPUTE)
|
||||
p.data = self.master_params[p].payload
|
||||
# Now p.data is sharded
|
||||
# So optimizer states are sharded naturally
|
||||
|
||||
def _write_back_data(self):
|
||||
# Copy master param data (fp32) to payload of colo_attr (fp16)
|
||||
# TODO() improve efficiency by gathering tensors into a chunk and transfering
|
||||
# a chunk.
|
||||
for group in self.optim.param_groups:
|
||||
for p in group['params']:
|
||||
is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
|
||||
if not is_param_sharded:
|
||||
# We use ZeRO-2 here
|
||||
# The `p.colo_attr.sharded_data_tensor` saves full fp16 param
|
||||
# But we only have updated fp32 param shard here
|
||||
# So we first shard full fp16 param and copy fp32 param shard to it
|
||||
# Then we will gather them
|
||||
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
|
||||
# We have to use `copy_payload` instead of `reset_payload`
|
||||
# Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
|
||||
|
||||
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
|
||||
p.colo_attr.sharded_data_tensor.reset_payload(
|
||||
colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
|
||||
|
||||
if not is_param_sharded:
|
||||
# We gather full fp16 param here
|
||||
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
|
||||
p.data = p.colo_attr.sharded_data_tensor.payload
|
||||
self.master_params[p].trans_state(TensorState.HOLD)
|
||||
p.colo_attr.saved_grad.set_null()
|
||||
|
|
|
@ -10,7 +10,6 @@ class ShardedParamV2(object):
|
|||
|
||||
def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
|
||||
self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
|
||||
self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
|
||||
self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
|
||||
# This attribute must be initialized in ShardedModel
|
||||
self.offload_grad: bool = False
|
||||
|
@ -57,10 +56,6 @@ class ShardedParamV2(object):
|
|||
_update_mem_use(self.sharded_data_tensor.payload)
|
||||
address_set.add(self.sharded_data_tensor.payload.data_ptr())
|
||||
|
||||
if not self.fp16_grad.is_null() and self.fp16_grad.data_ptr() not in address_set:
|
||||
_update_mem_use(self.fp16_grad.payload)
|
||||
address_set.add(self.fp16_grad.data_ptr())
|
||||
|
||||
if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
|
||||
_update_mem_use(self.saved_grad.payload)
|
||||
address_set.add(self.saved_grad.data_ptr())
|
||||
|
|
|
@ -63,12 +63,6 @@ def _run_shard_param_v2(rank, world_size, port):
|
|||
# 4 is size of dummy tensor of param.data
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
|
||||
|
||||
sparam.fp16_grad = StatefulTensor(torch.randn(2, 3).cuda().half())
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
|
||||
assert cuda_mem_use == 2 * 3 * 2
|
||||
|
||||
sparam.fp16_grad = StatefulTensor(None)
|
||||
sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
|
||||
sparam.remove_torch_payload()
|
||||
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
|
||||
|
|
Loading…
Reference in New Issue