[zero] trace states of fp16/32 grad and fp32 param (#571)

pull/561/head
ver217 3 years ago committed by GitHub
parent 7675366fce
commit 7c6c427db1

@@ -51,9 +51,9 @@ def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_
"""
A colossal API for model data tensor move.
The source and target tensors can reside on either CPU or GPU.
NOTE() The source tensor payload will be removed after this function.
The function will record the communication volume between CPU and GPU.
Args:
t_src (Union[StatefulTensor, torch.Tensor]): source tensor
@@ -93,7 +93,7 @@ def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], t
raise TypeError(f'colo_model_data_move_to_cpu does not accept type {type(t)}')
if isinstance(target_device, int):
target_device = torch.cuda(f'device"{target_device}')
target_device = torch.device(f'cuda:{target_device}')
# deal with torch.device('cpu') and torch.device('cpu:0')
if t_payload.device.type == target_device.type:
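For reference, the corrected line above builds a proper torch.device from an integer index instead of calling the nonexistent torch.cuda(...). A minimal standalone illustration (normalize_device is a hypothetical helper, not part of this patch):

import torch

def normalize_device(target_device):
    # Accept an int GPU index, a string, or a torch.device and return a torch.device.
    if isinstance(target_device, int):
        # Mirrors the fixed line: torch.device(f'cuda:{target_device}'), not torch.cuda(...).
        target_device = torch.device(f'cuda:{target_device}')
    elif isinstance(target_device, str):
        target_device = torch.device(target_device)
    return target_device

print(normalize_device(0))      # cuda:0
print(normalize_device('cpu'))  # cpu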

@@ -18,7 +18,7 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
from colossalai.utils.memory_utils.utils import (colo_cuda_memory_capacity, colo_model_data_move_to_cpu)
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
from colossalai.zero.sharded_param.tensorful_state import TensorState
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
@@ -245,27 +245,7 @@ class ShardedModelV2(nn.Module):
# We also allow interleaving no-sync passes with sync passes, if desired.
if not self._require_backward_grad_sync:
continue
# Reduced grad is saved in `p.colo_attr.saved_grad`
# It can be on CPU or CUDA
# It can be fp16 or fp32
# We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`.
if self.reuse_fp16_shard:
grad_fp16_payload = p.colo_attr.sharded_data_tensor.payload
else:
grad_fp16_payload = cast_tensor_to_fp32(p.colo_attr.fp16_grad.payload)
assert isinstance(grad_fp16_payload, torch.Tensor)
if p.colo_attr.offload_grad:
colo_model_data_move_to_cpu(grad_fp16_payload)
if not p.colo_attr.saved_grad.is_null():
assert not self.reuse_fp16_shard, 'Gradient accumulation is not supported when reuse_fp16_shard=True'
# Accumulate grad, saved grad must be fp32
p.colo_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.colo_attr.saved_grad.payload))
p.colo_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.colo_attr.saved_grad.payload))
else:
p.colo_attr.saved_grad.reset_payload(grad_fp16_payload)
p.grad = None
p.colo_attr.fp16_grad.set_null()
@torch.no_grad()
def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
@@ -322,11 +302,22 @@ class ShardedModelV2(nn.Module):
if self.gradient_postdivide_factor > 1:
# Average grad by world_size for consistency with PyTorch DDP.
reduced_grad.data.div_(self.gradient_postdivide_factor)
# FIXME(ver217): remove the line below once the eviction policy is implemented
if param.colo_attr.offload_grad:
colo_model_data_move_to_cpu(reduced_grad)
if self.reuse_fp16_shard:
param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
assert param.colo_attr.saved_grad.is_null(
), 'Gradient accumulation is not supported when reuse_fp16_shard=True'
param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad)
param.colo_attr.sharded_data_tensor.is_sharded = True
param.colo_attr.saved_grad.reset_payload(param.colo_attr.sharded_data_tensor.payload)
else:
param.colo_attr.fp16_grad = StatefulTensor(reduced_grad.data)
reduced_grad = cast_tensor_to_fp32(reduced_grad)
if param.colo_attr.saved_grad.is_null():
param.colo_attr.saved_grad.reset_payload(reduced_grad)
else:
param.colo_attr.saved_grad.payload.add_(reduced_grad.view_as(param.colo_attr.saved_grad.payload))
param.colo_attr.saved_grad.trans_state(TensorState.HOLD)
def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
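To make the control flow of the hunk above easier to follow, here is a self-contained sketch that uses a toy stand-in for StatefulTensor (ToyStatefulTensor and handle_reduced_grad are illustrative names only; the real ColossalAI class has more states and methods, and the reuse_fp16_shard branch also retargets sharded_data_tensor):

import torch

class ToyStatefulTensor:
    # Toy stand-in for StatefulTensor: a payload plus a coarse state tag.
    def __init__(self, tensor=None, state='FREE'):
        self.payload = tensor
        self.state = state

    def is_null(self):
        return self.payload is None

    def reset_payload(self, tensor):
        self.payload = tensor
        self.state = 'HOLD'

    def trans_state(self, state):
        self.state = state

def handle_reduced_grad(saved_grad, reduced_grad, reuse_fp16_shard):
    # Mirror the post-backward logic above: either reuse the fp16 shard as the grad
    # buffer, or accumulate an fp32 copy into saved_grad.
    if reuse_fp16_shard:
        assert saved_grad.is_null(), 'Gradient accumulation is not supported when reuse_fp16_shard=True'
        saved_grad.reset_payload(reduced_grad)
    else:
        reduced_grad = reduced_grad.float()          # cast_tensor_to_fp32 in the real code
        if saved_grad.is_null():
            saved_grad.reset_payload(reduced_grad)
        else:
            saved_grad.payload.add_(reduced_grad.view_as(saved_grad.payload))
    saved_grad.trans_state('HOLD')

g = ToyStatefulTensor()
handle_reduced_grad(g, torch.ones(4, dtype=torch.float16), reuse_fp16_shard=False)
handle_reduced_grad(g, torch.ones(4, dtype=torch.float16), reuse_fp16_shard=False)
print(g.payload)   # tensor([2., 2., 2., 2.]): the second grad was accumulated in fp32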

@@ -12,11 +12,12 @@ from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils.memory_tracer.model_data_memtracer import \
GLOBAL_MODEL_DATA_TRACER
from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move, colo_model_tensor_clone,
from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
colo_tensor_mem_usage)
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
from colossalai.zero.sharded_optim._utils import has_inf_or_nan
from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
from torch import Tensor
from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
@@ -112,7 +113,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self._logger = get_dist_logger("ShardedOptimizerV2")
# Store fp32 param shards
self.master_params: Dict[Parameter, Tensor] = {}
self.master_params: Dict[Parameter, StatefulTensor] = {}
for group in self.optim.param_groups:
for p in group['params']:
@@ -123,7 +124,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# Param is not sharded, which means we use ZeRO-2 here
# As we only store param shard, we shard it here
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
self.master_params[p] = cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device)
self.master_params[p] = StatefulTensor(
cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
if not is_param_sharded:
# In this branch, there's no need to shard param
# So we gather here
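A minimal illustration of what the StatefulTensor wrapping above buys (toy class, not the real API): the fp32 master shard now carries a state tag and is accessed through .payload, which later hunks rely on for trans_state(...) and in-place device moves.

import torch

class ToyStatefulTensor:
    # Toy stand-in for StatefulTensor: just a payload plus a coarse state tag.
    def __init__(self, tensor):
        self.payload = tensor
        self.state = 'HOLD'

    def trans_state(self, state):
        self.state = state

fp16_shard = torch.randn(4, dtype=torch.float16)
master = ToyStatefulTensor(fp16_shard.float())   # cast_tensor_to_fp32(...).to(device) in the patch
master.trans_state('COMPUTE')                    # COMPUTE during _prepare_data(), HOLD after write-back
print(master.payload.dtype, master.state)        # torch.float32 COMPUTE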
@@ -184,13 +186,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self.zero_grad()
return
# assign master param pointers to p.data.
# We will not trigger data copy here.
for group in self.optim.param_groups:
for p in group['params']:
p.data = self.master_params[p]
# Now p.data is sharded
# So optimizer states are sharded naturally
self._prepare_data()
self._logger.debug(
f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
@@ -201,30 +197,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
self._logger.debug(
f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
ranks=[0])
# Copy master param data (fp32) to payload of colo_attr (fp16)
# TODO() improve efficiency by gathering tensors into a chunk and transferring
# a chunk.
for group in self.optim.param_groups:
for p in group['params']:
is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
if not is_param_sharded:
# We use ZeRO-2 here
# The `p.colo_attr.sharded_data_tensor` saves full fp16 param
# But we only have updated fp32 param shard here
# So we first shard full fp16 param and copy fp32 param shard to it
# Then we will gather them
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
# We have to use `copy_payload` instead of `reset_payload`
# Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
p.colo_attr.sharded_data_tensor.reset_payload(
colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
if not is_param_sharded:
# We gather full fp16 param here
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
p.data = p.colo_attr.sharded_data_tensor.payload
self._write_back_data()
return ret
def backward(self, loss: Tensor) -> None:
@@ -276,6 +249,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
# Because we will judge whether local grad accumulation
# is enabled by whether grad is None
self.optim.zero_grad(set_to_none=True)
for group in self.optim.param_groups:
for p in group['params']:
p.colo_attr.saved_grad.set_null()
def sync_grad(self):
pass
@@ -288,9 +264,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
fp32_shards_used_cuda_margin_mem = 0
for group in self.optim.param_groups:
for p in group['params']:
shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size()
if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
colo_model_data_tensor_move_inline(self.master_params[p], torch.cuda.current_device())
p.grad.data = p.grad.data.to(torch.cuda.current_device())
p.colo_attr.offload_grad = False
fp32_shards_used_cuda_margin_mem += shard_mem
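The loop above is a greedy placement of fp32 master shards into the leftover CUDA memory margin. A hedged, standalone sketch of that policy (place_master_shards and its arguments are hypothetical; the real code moves payloads with colo_model_data_tensor_move_inline and flips offload_grad):

import torch

def place_master_shards(shards, cuda_margin_bytes):
    # Keep moving shards to the GPU until the margin is exhausted; leave the rest on CPU.
    used = 0
    placement = []
    for shard in shards:
        shard_mem = shard.numel() * shard.element_size()
        if used + shard_mem < cuda_margin_bytes:
            placement.append('cuda')     # real code: colo_model_data_tensor_move_inline(...)
            used += shard_mem
        else:
            placement.append('cpu')      # grad stays offloaded, offload_grad remains True
    return placement

shards = [torch.zeros(1024, dtype=torch.float32) for _ in range(4)]  # 4 KiB each
print(place_master_shards(shards, cuda_margin_bytes=10 * 1024))      # ['cuda', 'cuda', 'cpu', 'cpu']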
@@ -298,10 +274,50 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
def _prepare_grads(self):
for group in self.optim.param_groups:
for p in group['params']:
p.colo_attr.saved_grad.trans_state(TensorState.COMPUTE)
# FIXME(ver217): p.data here is an empty tensor on CUDA and carries no useful information
# If we changed p.grad directly,
# it may raise an error because p.data and p.grad differ in shape/dtype/device
# So we just set p.data = p.colo_attr.saved_grad.payload here
p.data = p.colo_attr.saved_grad.payload
p.grad = p.colo_attr.saved_grad.payload
# Set p.data to an empty tensor to avoid leaking memory
p.colo_attr.remove_torch_payload()
def _prepare_data(self):
# assign master param pointers to p.data.
# We will not trigger data copy here.
for group in self.optim.param_groups:
for p in group['params']:
self.master_params[p].trans_state(TensorState.COMPUTE)
p.data = self.master_params[p].payload
# Now p.data is sharded
# So optimizer states are sharded naturally
def _write_back_data(self):
# Copy master param data (fp32) to payload of colo_attr (fp16)
# TODO() improve efficiency by gathering tensors into a chunk and transferring
# a chunk.
for group in self.optim.param_groups:
for p in group['params']:
is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
if not is_param_sharded:
# We use ZeRO-2 here
# The `p.colo_attr.sharded_data_tensor` saves full fp16 param
# But we only have updated fp32 param shard here
# So we first shard full fp16 param and copy fp32 param shard to it
# Then we will gather them
self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
# We have to use `copy_payload` instead of `reset_payload`
# Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
# TODO() optimize this line CPU (fp32) -> GPU (fp16)
p.colo_attr.sharded_data_tensor.reset_payload(
colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
if not is_param_sharded:
# We gather full fp16 param here
self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
p.data = p.colo_attr.sharded_data_tensor.payload
self.master_params[p].trans_state(TensorState.HOLD)
p.colo_attr.saved_grad.set_null()
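The three helpers above split step() into prepare-grads / prepare-data / write-back phases. Below is a self-contained toy of that lifecycle for a single parameter (ToyShardedParam and the literal values are illustrative only; the real code goes through StatefulTensor payloads, shard/gather, and colo_model_tensor_clone):

import torch

class ToyShardedParam:
    def __init__(self, fp16_shard: torch.Tensor):
        self.fp16_payload = fp16_shard    # stands in for sharded_data_tensor.payload
        self.saved_grad = None            # fp32 grad produced by the backward hook

param = torch.nn.Parameter(torch.empty(0))            # p.data stays empty between steps
colo = ToyShardedParam(torch.full((4,), 2.0, dtype=torch.float16))
master = colo.fp16_payload.float()                    # fp32 master shard (cast once at init)
colo.saved_grad = torch.full((4,), 0.5)               # pretend backward produced this fp32 grad

# _prepare_grads: expose the saved grad as p.grad (p.data must match its shape/dtype first)
param.data = colo.saved_grad
param.grad = colo.saved_grad
# _prepare_data: point p.data at the fp32 master shard, so optimizer states stay sharded
param.data = master
optim = torch.optim.SGD([param], lr=0.1)
optim.step()
# _write_back_data: cast the updated fp32 shard back into the fp16 payload
colo.fp16_payload.copy_(param.data.half())
param.grad = None
colo.saved_grad = None
print(colo.fp16_payload)    # roughly 2.0 - 0.1 * 0.5 = 1.95, stored in fp16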

@@ -10,7 +10,6 @@ class ShardedParamV2(object):
def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
# This attribute must be initialized in ShardedModel
self.offload_grad: bool = False
@@ -57,10 +56,6 @@ class ShardedParamV2(object):
_update_mem_use(self.sharded_data_tensor.payload)
address_set.add(self.sharded_data_tensor.payload.data_ptr())
if not self.fp16_grad.is_null() and self.fp16_grad.data_ptr() not in address_set:
_update_mem_use(self.fp16_grad.payload)
address_set.add(self.fp16_grad.data_ptr())
if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
_update_mem_use(self.saved_grad.payload)
address_set.add(self.saved_grad.data_ptr())
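With fp16_grad gone, the memory accounting above only counts the sharded data payload and saved_grad, deduplicated by data_ptr() so aliased storage (e.g. saved_grad reusing the fp16 shard) is counted once. A hedged sketch of that accounting (get_memory_usage here is a simplified stand-in, not the real method):

import torch

def get_memory_usage(tensors):
    # Sum bytes per device, skipping None payloads and tensors whose storage was already seen.
    cuda, cpu, seen = 0, 0, set()
    for t in tensors:
        if t is None or t.data_ptr() in seen:
            continue
        seen.add(t.data_ptr())
        nbytes = t.numel() * t.element_size()
        if t.is_cuda:
            cuda += nbytes
        else:
            cpu += nbytes
    return cuda, cpu

data = torch.randn(2, 3, dtype=torch.float16)
saved_grad = data                             # aliases the same storage: counted once
print(get_memory_usage([data, saved_grad]))   # (0, 12)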

@@ -63,12 +63,6 @@ def _run_shard_param_v2(rank, world_size, port):
# 4 is size of dummy tensor of param.data
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
sparam.fp16_grad = StatefulTensor(torch.randn(2, 3).cuda().half())
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
assert cuda_mem_use == 2 * 3 * 2
sparam.fp16_grad = StatefulTensor(None)
sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
sparam.remove_torch_payload()
cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
