[zero] trace states of fp16/32 grad and fp32 param (#571)

pull/561/head
Author: ver217
Date: 2022-03-31 16:26:54 +08:00 (committed by GitHub)
parent 7675366fce
commit 7c6c427db1
5 changed files with 69 additions and 73 deletions

View File

@@ -51,9 +51,9 @@ def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_
    """
    A colossal API for model data tensor move.
    The src and target tensors could be resident on both CPU and GPU.
    NOTE() The source tensor payload will be removed after this function.
    The function will record the communication volume between CPU and GPU.
    Args:
        t_src (Union[StatefulTensor, torch.Tensor]): source tensor
@@ -93,7 +93,7 @@ def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], t
        raise TypeError('colo_model_data_move_to_cpu dose not accept type {type(t)}')
    if isinstance(target_device, int):
-       target_device = torch.cuda(f'device"{target_device}')
+       target_device = torch.device(f'cuda:{target_device}')
    # deal with torch.device('cpu') and torch.device('cpu:0)
    if t_payload.device.type == target_device.type:
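The second hunk above replaces a malformed device constructor: as rendered here, the old line called `torch.cuda(...)`, which is a module and not callable, so an integer GPU index could never be normalized; the new line builds a `torch.device` from the index. A minimal sketch of that normalization, assuming nothing beyond PyTorch (the helper name `normalize_device` is illustrative, not part of the commit):

    import torch

    def normalize_device(target_device):
        # Accept an int GPU index or an existing torch.device and
        # always return a torch.device, mirroring the fixed line above.
        if isinstance(target_device, int):
            target_device = torch.device(f'cuda:{target_device}')
        return target_device

    assert normalize_device(1) == torch.device('cuda:1')
    assert normalize_device(torch.device('cpu')).type == 'cpu'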

View File

@@ -18,7 +18,7 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
 from colossalai.utils.memory_utils.utils import (colo_cuda_memory_capacity, colo_model_data_move_to_cpu)
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
-from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
+from colossalai.zero.sharded_param.tensorful_state import TensorState
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
@@ -245,27 +245,7 @@ class ShardedModelV2(nn.Module):
             # We also allows to interleave no-sync pass with sync passes, if desired.
             if not self._require_backward_grad_sync:
                 continue
-            # Reduced grad is saved in `p.colo_attr.saved_grad`
-            # It can be on CPU or CUDA
-            # It can be fp16 or fp32
-            # We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`.
-            if self.reuse_fp16_shard:
-                grad_fp16_payload = p.colo_attr.sharded_data_tensor.payload
-            else:
-                grad_fp16_payload = cast_tensor_to_fp32(p.colo_attr.fp16_grad.payload)
-            assert isinstance(grad_fp16_payload, torch.Tensor)
-            if p.colo_attr.offload_grad:
-                colo_model_data_move_to_cpu(grad_fp16_payload)
-            if not p.colo_attr.saved_grad.is_null():
-                assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
-                # Accumulate grad, saved grad must be fp32
-                p.colo_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.colo_attr.saved_grad.payload))
-                p.colo_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.colo_attr.saved_grad.payload))
-            else:
-                p.colo_attr.saved_grad.reset_payload(grad_fp16_payload)
             p.grad = None
-            p.colo_attr.fp16_grad.set_null()

     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
@@ -322,11 +302,22 @@
         if self.gradient_postdivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
+        # FIXME(ver217): remove the below line when impl eviction policy
+        if param.colo_attr.offload_grad:
+            colo_model_data_move_to_cpu(reduced_grad)
         if self.reuse_fp16_shard:
-            param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
+            assert param.colo_attr.saved_grad.is_null(
+            ), 'Gradien accumulation is not supported when reuse_fp16_shard=True'
+            param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad)
             param.colo_attr.sharded_data_tensor.is_sharded = True
+            param.colo_attr.saved_grad.reset_payload(param.colo_attr.sharded_data_tensor.payload)
         else:
-            param.colo_attr.fp16_grad = StatefulTensor(reduced_grad.data)
+            reduced_grad = cast_tensor_to_fp32(reduced_grad)
+            if param.colo_attr.saved_grad.is_null():
+                param.colo_attr.saved_grad.reset_payload(reduced_grad)
+            else:
+                param.colo_attr.saved_grad.payload.add_(reduced_grad.view_as(param.colo_attr.saved_grad.payload))
+        param.colo_attr.saved_grad.trans_state(TensorState.HOLD)

     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],
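The rewritten reduce-scatter callback above makes `saved_grad` the single home for reduced gradients: with `reuse_fp16_shard` the fp16 parameter shard itself becomes the gradient buffer (so accumulation is disallowed), otherwise the reduced gradient is cast to fp32 and accumulated into `saved_grad` across micro-batches, and in both cases the state moves to HOLD. A condensed sketch of that branching, using a hypothetical `SimpleStatefulTensor` stand-in rather than the real Colossal-AI `StatefulTensor`:

    import torch

    class SimpleStatefulTensor:
        # Hypothetical stand-in: a payload plus a coarse state tag.
        def __init__(self, payload=None, state='FREE'):
            self.payload, self.state = payload, state

        def is_null(self):
            return self.payload is None

        def reset_payload(self, tensor):
            self.payload = tensor

    def handle_reduced_grad(saved_grad, reduced_grad, reuse_fp16_shard, fp16_shard=None):
        if reuse_fp16_shard:
            # The fp16 param shard doubles as grad storage; accumulation is unsupported.
            assert saved_grad.is_null(), 'grad accumulation not supported when reusing the fp16 shard'
            fp16_shard.reset_payload(reduced_grad)
            saved_grad.reset_payload(fp16_shard.payload)
        else:
            reduced_grad = reduced_grad.float()            # cast_tensor_to_fp32
            if saved_grad.is_null():
                saved_grad.reset_payload(reduced_grad)     # first micro-batch
            else:
                # later micro-batches accumulate into the fp32 buffer
                saved_grad.payload.add_(reduced_grad.view_as(saved_grad.payload))
        saved_grad.state = 'HOLD'                          # trans_state(TensorState.HOLD)

    # usage: two accumulation steps without shard reuse
    g = SimpleStatefulTensor()
    handle_reduced_grad(g, torch.ones(4, dtype=torch.float16), reuse_fp16_shard=False)
    handle_reduced_grad(g, torch.ones(4, dtype=torch.float16), reuse_fp16_shard=False)
    assert torch.equal(g.payload, torch.full((4,), 2.0))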

View File

@@ -12,11 +12,12 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move, colo_model_tensor_clone,
+from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
                                                   colo_tensor_mem_usage)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
+from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
@@ -112,7 +113,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._logger = get_dist_logger("ShardedOptimizerV2")

         # Store fp32 param shards
-        self.master_params: Dict[Parameter, Tensor] = {}
+        self.master_params: Dict[Parameter, StatefulTensor] = {}

         for group in self.optim.param_groups:
             for p in group['params']:
@@ -123,7 +124,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                     # Param is no sharded, which means we use ZeRO-2 here
                     # As we only store param shard, we shard it here
                     self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                self.master_params[p] = cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device)
+                self.master_params[p] = StatefulTensor(
+                    cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
                 if not is_param_sharded:
                     # In this branch, there's no need to shard param
                     # So we gather here
@@ -184,13 +186,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             self.zero_grad()
             return

-        # assign master param pointers to p.data.
-        # We will not trigger data copy here.
-        for group in self.optim.param_groups:
-            for p in group['params']:
-                p.data = self.master_params[p]
-                # Now p.data is sharded
-                # So optimizer states are sharded naturally
+        self._prepare_data()

         self._logger.debug(
             f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
@@ -201,30 +197,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._logger.debug(
             f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
             ranks=[0])
-        # Copy master param data (fp32) to payload of colo_attr (fp16)
-        # TODO() improve efficiency by gathering tensors into a chunk and transfering
-        # a chunk.
-        for group in self.optim.param_groups:
-            for p in group['params']:
-                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
-                if not is_param_sharded:
-                    # We use ZeRO-2 here
-                    # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
-                    # But we only have updated fp32 param shard here
-                    # So we first shard full fp16 param and copy fp32 param shard to it
-                    # Then we will gather them
-                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                # We have to use `copy_payload` instead of `reset_payload`
-                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
-                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.colo_attr.sharded_data_tensor.reset_payload(
-                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
-                if not is_param_sharded:
-                    # We gather full fp16 param here
-                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                p.data = p.colo_attr.sharded_data_tensor.payload
+        self._write_back_data()
         return ret

     def backward(self, loss: Tensor) -> None:
@@ -276,6 +249,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Because we will judge whether local grad accumulation
         # is enabled by wheter grad is None
         self.optim.zero_grad(set_to_none=True)
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                p.colo_attr.saved_grad.set_null()

     def sync_grad(self):
         pass
@@ -288,9 +264,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         fp32_shards_used_cuda_margin_mem = 0
         for group in self.optim.param_groups:
             for p in group['params']:
-                shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
+                shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size()
                 if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
-                    self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
+                    colo_model_data_tensor_move_inline(self.master_params[p], torch.cuda.current_device())
                     p.grad.data = p.grad.data.to(torch.cuda.current_device())
                     p.colo_attr.offload_grad = False
                     fp32_shards_used_cuda_margin_mem += shard_mem
@@ -298,10 +274,50 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     def _prepare_grads(self):
         for group in self.optim.param_groups:
             for p in group['params']:
+                p.colo_attr.saved_grad.trans_state(TensorState.COMPUTE)
                 # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful infomation
                 # If we change p.grad directly
                 # it may raise error because of different shape/dtype/device of p.data and p.grad
                 # We just set p.data = p.colo_attr.saved_grad.payload here
                 p.data = p.colo_attr.saved_grad.payload
                 p.grad = p.colo_attr.saved_grad.payload
+                # Set p.data to empty tensor, in case of memory leaking
+                p.colo_attr.remove_torch_payload()
+
+    def _prepare_data(self):
+        # assign master param pointers to p.data.
+        # We will not trigger data copy here.
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                self.master_params[p].trans_state(TensorState.COMPUTE)
+                p.data = self.master_params[p].payload
+                # Now p.data is sharded
+                # So optimizer states are sharded naturally
+
+    def _write_back_data(self):
+        # Copy master param data (fp32) to payload of colo_attr (fp16)
+        # TODO() improve efficiency by gathering tensors into a chunk and transfering
+        # a chunk.
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
+                if not is_param_sharded:
+                    # We use ZeRO-2 here
+                    # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
+                    # But we only have updated fp32 param shard here
+                    # So we first shard full fp16 param and copy fp32 param shard to it
+                    # Then we will gather them
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                # We have to use `copy_payload` instead of `reset_payload`
+                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
+                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
+                p.colo_attr.sharded_data_tensor.reset_payload(
+                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
+                if not is_param_sharded:
+                    # We gather full fp16 param here
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                p.data = p.colo_attr.sharded_data_tensor.payload
+                self.master_params[p].trans_state(TensorState.HOLD)
                 p.colo_attr.saved_grad.set_null()
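After this refactor `step()` is a fixed pipeline: `_prepare_grads()` exposes the reduced fp32 gradients as `p.grad` (routing them through `p.data` first so the dtype/shape checks on the grad assignment pass), `_prepare_data()` points `p.data` at the fp32 master shard, the wrapped optimizer updates that shard in place, and `_write_back_data()` casts the result back into the fp16 payload, restores HOLD, and nulls `saved_grad`. A schematic sketch of that ordering, with plain tensors and dicts standing in for `StatefulTensor` and the sharding machinery, so only the data/grad pointer swaps are meaningful:

    import torch

    def sharded_step(optimizer, params, master_params, saved_grads, fp16_payloads):
        # _prepare_grads: point p.data at the fp32 grad first so the p.grad
        # assignment sees matching dtype/shape/device, then set p.grad.
        for p in params:
            p.data = saved_grads[p]
            p.grad = saved_grads[p]
        # _prepare_data: let the optimizer update the fp32 master copy in place.
        for p in params:
            p.data = master_params[p]
        optimizer.step()
        # _write_back_data: copy the updated fp32 master back into the fp16
        # payload and drop the consumed gradient (saved_grad.set_null()).
        for p in params:
            p.grad = None
            fp16_payloads[p].copy_(master_params[p].half())
            p.data = fp16_payloads[p]
            saved_grads[p] = None

    # usage sketch on CPU
    w = torch.nn.Parameter(torch.zeros(4, dtype=torch.float16))
    fp16 = {w: w.data}
    master = {w: w.data.float()}
    grads = {w: torch.ones(4)}
    opt = torch.optim.SGD([w], lr=0.1)
    sharded_step(opt, [w], master, grads, fp16)
    assert torch.allclose(w.float(), torch.full((4,), -0.1), atol=1e-3)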

View File

@@ -10,7 +10,6 @@ class ShardedParamV2(object):
     def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
         self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
-        self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         # This attribute must be initialized in ShardedModel
         self.offload_grad: bool = False
@@ -57,10 +56,6 @@
         _update_mem_use(self.sharded_data_tensor.payload)
         address_set.add(self.sharded_data_tensor.payload.data_ptr())

-        if not self.fp16_grad.is_null() and self.fp16_grad.data_ptr() not in address_set:
-            _update_mem_use(self.fp16_grad.payload)
-            address_set.add(self.fp16_grad.data_ptr())
-
         if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
             _update_mem_use(self.saved_grad.payload)
             address_set.add(self.saved_grad.data_ptr())
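With `fp16_grad` gone, `get_memory_usage` above only has to account for the parameter payload and `saved_grad`, and the `data_ptr()` set keeps aliased payloads (the `reuse_fp16_shard` case, where `saved_grad` points at the parameter shard) from being counted twice. A small illustration of that deduplication idea, independent of the Colossal-AI classes:

    import torch

    def payload_bytes(tensors):
        # Count each underlying buffer once, keyed by data_ptr().
        seen, total = set(), 0
        for t in tensors:
            if t is None or t.data_ptr() in seen:
                continue
            seen.add(t.data_ptr())
            total += t.numel() * t.element_size()
        return total

    shard = torch.zeros(2, 3, dtype=torch.float16)
    saved_grad = shard                      # aliased, as with reuse_fp16_shard
    assert payload_bytes([shard, saved_grad]) == 2 * 3 * 2   # 12 bytes, counted once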

View File

@@ -63,12 +63,6 @@ def _run_shard_param_v2(rank, world_size, port):
     # 4 is size of dummy tensor of param.data
     assert cpu_mem_use == 2 * 3 * 4 * 2 + 4

-    sparam.fp16_grad = StatefulTensor(torch.randn(2, 3).cuda().half())
-    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
-    assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
-    assert cuda_mem_use == 2 * 3 * 2
-
-    sparam.fp16_grad = StatefulTensor(None)
     sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
     sparam.remove_torch_payload()
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
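For reference, the magic numbers in the surviving assertions are byte counts: a 2x3 fp32 tensor occupies 2 * 3 * 4 = 24 bytes, and the CPU figure appears to cover two such tensors plus the 4-byte dummy payload mentioned in the test comment, hence 2 * 3 * 4 * 2 + 4 = 52; the deleted CUDA check was one 2x3 fp16 tensor, 2 * 3 * 2 = 12 bytes. A few lines to reproduce the arithmetic under that reading:

    import torch

    fp32_shard = torch.randn(2, 3)
    assert fp32_shard.numel() * fp32_shard.element_size() == 24
    assert 2 * 24 + 4 == 2 * 3 * 4 * 2 + 4 == 52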