mirror of https://github.com/hpcaitech/ColossalAI
[zero] trace states of fp16/32 grad and fp32 param (#571)
parent 7675366fce
commit 7c6c427db1

@@ -51,9 +51,9 @@ def colo_model_data_tensor_move(src_t: Union[StatefulTensor, torch.Tensor], tgt_
     """
     A colossal API for model data tensor move.
     The src and target tensors could be resident on both CPU and GPU.
 
     NOTE() The source tensor payload will be removed after this function.
 
     The function will record the communication volume between CPU and GPU.
     Args:
         t_src (Union[StatefulTensor, torch.Tensor]): source tensor

@@ -93,7 +93,7 @@ def colo_model_data_tensor_move_inline(t: Union[StatefulTensor, torch.Tensor], t
         raise TypeError('colo_model_data_move_to_cpu dose not accept type {type(t)}')
 
     if isinstance(target_device, int):
-        target_device = torch.cuda(f'device"{target_device}')
+        target_device = torch.device(f'cuda:{target_device}')
 
     # deal with torch.device('cpu') and torch.device('cpu:0)
     if t_payload.device.type == target_device.type:

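The fixed line above normalizes an integer device index into a proper `torch.device`. Below is a minimal, self-contained sketch of that normalization; the helper name `normalize_device` is ours for illustration, not a ColossalAI API.

import torch

def normalize_device(target_device):
    # An integer is interpreted as a CUDA device index, mirroring the corrected line above.
    if isinstance(target_device, int):
        target_device = torch.device(f'cuda:{target_device}')
    return target_device

assert normalize_device(1) == torch.device('cuda:1')
assert normalize_device(torch.device('cpu')) == torch.device('cpu')
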
@@ -18,7 +18,7 @@ from colossalai.utils.memory_tracer.model_data_memtracer import \
 from colossalai.utils.memory_utils.utils import (colo_cuda_memory_capacity, colo_model_data_move_to_cpu)
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
-from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
+from colossalai.zero.sharded_param.tensorful_state import TensorState
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 

@@ -245,27 +245,7 @@ class ShardedModelV2(nn.Module):
             # We also allows to interleave no-sync pass with sync passes, if desired.
             if not self._require_backward_grad_sync:
                 continue
-            # Reduced grad is saved in `p.colo_attr.saved_grad`
-            # It can be on CPU or CUDA
-            # It can be fp16 or fp32
-            # We set `p.grad` to None here and ShardedOptimizer will prepare `p.grad` before `step()`.
-            if self.reuse_fp16_shard:
-                grad_fp16_payload = p.colo_attr.sharded_data_tensor.payload
-            else:
-                grad_fp16_payload = cast_tensor_to_fp32(p.colo_attr.fp16_grad.payload)
-            assert isinstance(grad_fp16_payload, torch.Tensor)
-            if p.colo_attr.offload_grad:
-                colo_model_data_move_to_cpu(grad_fp16_payload)
-            if not p.colo_attr.saved_grad.is_null():
-                assert not self.reuse_fp16_shard, 'Gradien accumulation is not supported when reuse_fp16_shard=True'
-                # Accumulate grad, saved grad must be fp32
-                p.colo_attr.saved_grad.reset_payload(cast_tensor_to_fp32(p.colo_attr.saved_grad.payload))
-                p.colo_attr.saved_grad.payload.add_(grad_fp16_payload.view_as(p.colo_attr.saved_grad.payload))
-            else:
-                p.colo_attr.saved_grad.reset_payload(grad_fp16_payload)
-
             p.grad = None
-            p.colo_attr.fp16_grad.set_null()
 
     @torch.no_grad()
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:

@@ -322,11 +302,22 @@ class ShardedModelV2(nn.Module):
         if self.gradient_postdivide_factor > 1:
             # Average grad by world_size for consistency with PyTorch DDP.
             reduced_grad.data.div_(self.gradient_postdivide_factor)
+        # FIXME(ver217): remove the below line when impl eviction policy
+        if param.colo_attr.offload_grad:
+            colo_model_data_move_to_cpu(reduced_grad)
         if self.reuse_fp16_shard:
-            param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad.data)
+            assert param.colo_attr.saved_grad.is_null(
+            ), 'Gradien accumulation is not supported when reuse_fp16_shard=True'
+            param.colo_attr.sharded_data_tensor.reset_payload(reduced_grad)
             param.colo_attr.sharded_data_tensor.is_sharded = True
+            param.colo_attr.saved_grad.reset_payload(param.colo_attr.sharded_data_tensor.payload)
         else:
-            param.colo_attr.fp16_grad = StatefulTensor(reduced_grad.data)
+            reduced_grad = cast_tensor_to_fp32(reduced_grad)
+            if param.colo_attr.saved_grad.is_null():
+                param.colo_attr.saved_grad.reset_payload(reduced_grad)
+            else:
+                param.colo_attr.saved_grad.payload.add_(reduced_grad.view_as(param.colo_attr.saved_grad.payload))
+        param.colo_attr.saved_grad.trans_state(TensorState.HOLD)
 
     def state_dict(self, destination=None, prefix='', keep_vars=False) -> 'OrderedDict[str, torch.Tensor]':
         self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in self.module.parameters()],

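The reduce hook above now traces the reduced gradient straight into `saved_grad`: the fp16 grad is cast to fp32, accumulated if a grad is already held, and the tensor is marked HOLD. Below is a minimal, self-contained sketch of that accumulation path; `StatefulTensor` and `TensorState` here are simplified stand-ins for ColossalAI's classes, not the real implementation.

from enum import Enum, auto

import torch


class TensorState(Enum):
    # Simplified state tags mirroring the states referenced in the diff.
    FREE = auto()
    HOLD = auto()
    COMPUTE = auto()


class StatefulTensor:
    # Minimal stand-in: a payload plus a state tag.
    def __init__(self, tensor=None, state=TensorState.FREE):
        self.payload = tensor
        self.state = state

    def is_null(self):
        return self.payload is None

    def reset_payload(self, tensor):
        self.payload = tensor

    def trans_state(self, state):
        self.state = state


def save_reduced_grad(saved_grad: StatefulTensor, reduced_grad: torch.Tensor) -> None:
    # Non-reuse branch of the hook: cast to fp32, initialize or accumulate, then mark HOLD.
    reduced_grad = reduced_grad.float()
    if saved_grad.is_null():
        saved_grad.reset_payload(reduced_grad)
    else:
        saved_grad.payload.add_(reduced_grad.view_as(saved_grad.payload))
    saved_grad.trans_state(TensorState.HOLD)


saved = StatefulTensor()
save_reduced_grad(saved, torch.ones(4, dtype=torch.half))   # first backward: fp32 copy is stored
save_reduced_grad(saved, torch.ones(4, dtype=torch.half))   # second backward: accumulated in fp32
assert saved.payload.dtype == torch.float32 and saved.state is TensorState.HOLD
assert torch.equal(saved.payload, torch.full((4,), 2.0))
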
@@ -12,11 +12,12 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move, colo_model_tensor_clone,
+from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
                                                  colo_tensor_mem_usage)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
+from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter

@@ -112,7 +113,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._logger = get_dist_logger("ShardedOptimizerV2")
 
         # Store fp32 param shards
-        self.master_params: Dict[Parameter, Tensor] = {}
+        self.master_params: Dict[Parameter, StatefulTensor] = {}
 
         for group in self.optim.param_groups:
             for p in group['params']:

@@ -123,7 +124,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                 # Param is no sharded, which means we use ZeRO-2 here
                 # As we only store param shard, we shard it here
                 self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                self.master_params[p] = cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device)
+                self.master_params[p] = StatefulTensor(
+                    cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
                 if not is_param_sharded:
                     # In this branch, there's no need to shard param
                     # So we gather here

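With this change the optimizer wraps each fp32 master shard in a StatefulTensor instead of keeping a bare tensor, so the shard can later be state-traced and moved in place. A small sketch of building an fp32 master copy of an fp16 shard on a chosen device, using plain torch; the helper name `make_master_shard` is illustrative, not a ColossalAI API.

import torch

def make_master_shard(fp16_shard: torch.Tensor, device: torch.device) -> torch.Tensor:
    # fp32 master copy of the fp16 shard, placed where the optimizer keeps its state (e.g. CPU).
    return fp16_shard.detach().clone().float().to(device)

fp16_shard = torch.randn(8, dtype=torch.half)
master = make_master_shard(fp16_shard, torch.device('cpu'))
assert master.dtype == torch.float32 and master.device.type == 'cpu'
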
@@ -184,13 +186,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             self.zero_grad()
             return
 
-        # assign master param pointers to p.data.
-        # We will not trigger data copy here.
-        for group in self.optim.param_groups:
-            for p in group['params']:
-                p.data = self.master_params[p]
-                # Now p.data is sharded
-                # So optimizer states are sharded naturally
+        self._prepare_data()
 
         self._logger.debug(
             f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",

@@ -201,30 +197,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._logger.debug(
             f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
             ranks=[0])
-        # Copy master param data (fp32) to payload of colo_attr (fp16)
-        # TODO() improve efficiency by gathering tensors into a chunk and transfering
-        # a chunk.
-        for group in self.optim.param_groups:
-            for p in group['params']:
-                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
-                if not is_param_sharded:
-                    # We use ZeRO-2 here
-                    # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
-                    # But we only have updated fp32 param shard here
-                    # So we first shard full fp16 param and copy fp32 param shard to it
-                    # Then we will gather them
-                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                # We have to use `copy_payload` instead of `reset_payload`
-                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
-
-                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.colo_attr.sharded_data_tensor.reset_payload(
-                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
-
-                if not is_param_sharded:
-                    # We gather full fp16 param here
-                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                p.data = p.colo_attr.sharded_data_tensor.payload
+        self._write_back_data()
         return ret
 
     def backward(self, loss: Tensor) -> None:

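The `step()` body now delegates to `_prepare_data()` before the update and `_write_back_data()` after it (both appear later in this diff). The underlying pattern is the usual mixed-precision master-weights loop: point the optimizer at the fp32 shards, step, then copy the result back into the fp16 payload. A self-contained sketch of that pattern with plain torch tensors; the names and the hand-written SGD update are illustrative, not ColossalAI's code.

import torch

# One fp16 "shard" and its fp32 master copy, as the optimizer keeps them.
fp16_param = torch.randn(4, dtype=torch.half)
master = fp16_param.detach().clone().float()
fp32_grad = torch.ones(4)              # reduced gradient, already cast to fp32

# "Prepare": the optimizer updates the fp32 master shard, not the fp16 payload.
lr = 0.1
with torch.no_grad():
    master -= lr * fp32_grad           # stand-in for optim.step() on the master shard

# "Write back": copy the updated fp32 master into the fp16 payload used by the model.
fp16_param.data.copy_(master.half())
assert fp16_param.dtype == torch.float16
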
@@ -276,6 +249,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Because we will judge whether local grad accumulation
         # is enabled by wheter grad is None
         self.optim.zero_grad(set_to_none=True)
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                p.colo_attr.saved_grad.set_null()
 
     def sync_grad(self):
         pass

@@ -288,9 +264,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         fp32_shards_used_cuda_margin_mem = 0
         for group in self.optim.param_groups:
             for p in group['params']:
-                shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
+                shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size()
                 if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
-                    self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
+                    colo_model_data_tensor_move_inline(self.master_params[p], torch.cuda.current_device())
                     p.grad.data = p.grad.data.to(torch.cuda.current_device())
                     p.colo_attr.offload_grad = False
                     fp32_shards_used_cuda_margin_mem += shard_mem

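This loop greedily moves fp32 master shards (and their grads) onto the GPU while they still fit within the available CUDA margin memory; moving the StatefulTensor in place keeps the same wrapper object valid for other references. A minimal sketch of the budgeting logic over plain tensors; the function name and dict layout are illustrative, and the GPU move is skipped when CUDA is unavailable.

import torch

def place_shards_within_budget(shards, budget_bytes, device):
    # Greedily move shards to `device` until the byte budget is exhausted, like the loop above.
    used = 0
    for name, shard in shards.items():
        shard_mem = shard.numel() * shard.element_size()
        if used + shard_mem < budget_bytes:
            shards[name] = shard.to(device)
            used += shard_mem
    return used

shards = {'a': torch.randn(1024), 'b': torch.randn(1024), 'c': torch.randn(1024)}
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
used = place_shards_within_budget(shards, budget_bytes=2 * 1024 * 4 + 1, device=device)
assert used == 2 * 1024 * 4    # only the first two 4 KiB shards fit in the budget
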
@@ -298,10 +274,50 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     def _prepare_grads(self):
         for group in self.optim.param_groups:
             for p in group['params']:
+                p.colo_attr.saved_grad.trans_state(TensorState.COMPUTE)
                 # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful infomation
                 # If we change p.grad directly
                 # it may raise error because of different shape/dtype/device of p.data and p.grad
                 # We just set p.data = p.colo_attr.saved_grad.payload here
                 p.data = p.colo_attr.saved_grad.payload
                 p.grad = p.colo_attr.saved_grad.payload
+                # Set p.data to empty tensor, in case of memory leaking
+                p.colo_attr.remove_torch_payload()
+
+    def _prepare_data(self):
+        # assign master param pointers to p.data.
+        # We will not trigger data copy here.
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                self.master_params[p].trans_state(TensorState.COMPUTE)
+                p.data = self.master_params[p].payload
+                # Now p.data is sharded
+                # So optimizer states are sharded naturally
+
+    def _write_back_data(self):
+        # Copy master param data (fp32) to payload of colo_attr (fp16)
+        # TODO() improve efficiency by gathering tensors into a chunk and transfering
+        # a chunk.
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
+                if not is_param_sharded:
+                    # We use ZeRO-2 here
+                    # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
+                    # But we only have updated fp32 param shard here
+                    # So we first shard full fp16 param and copy fp32 param shard to it
+                    # Then we will gather them
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                # We have to use `copy_payload` instead of `reset_payload`
+                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
+
+                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
+                p.colo_attr.sharded_data_tensor.reset_payload(
+                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
+
+                if not is_param_sharded:
+                    # We gather full fp16 param here
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                p.data = p.colo_attr.sharded_data_tensor.payload
+                self.master_params[p].trans_state(TensorState.HOLD)
                 p.colo_attr.saved_grad.set_null()

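Taken together, these helpers give the traced tensors a simple lifecycle around `step()`: `_prepare_grads()` marks `saved_grad` COMPUTE and exposes it as `p.grad`, `_prepare_data()` marks the fp32 master shard COMPUTE and points `p.data` at it, and `_write_back_data()` returns the master shard to HOLD and nulls the consumed grad. Below is a compact sketch of that state trace using a toy state tag, not ColossalAI's real classes; treating set_null() as a transition back to FREE is our simplification.

from enum import Enum, auto

class State(Enum):
    FREE = auto()
    HOLD = auto()
    COMPUTE = auto()

class Traced:
    # Toy stateful wrapper that records every transition, to make the trace visible.
    def __init__(self, name):
        self.name, self.state, self.trace = name, State.FREE, [State.FREE]
    def trans_state(self, state):
        self.state = state
        self.trace.append(state)

saved_grad, master_param = Traced('saved_grad'), Traced('master_param')

# backward: the reduced grad is stored and held
saved_grad.trans_state(State.HOLD)
# step(): _prepare_grads / _prepare_data
saved_grad.trans_state(State.COMPUTE)
master_param.trans_state(State.COMPUTE)
# step(): _write_back_data after the optimizer update
master_param.trans_state(State.HOLD)
saved_grad.trans_state(State.FREE)       # set_null() conceptually frees the consumed grad

assert saved_grad.trace == [State.FREE, State.HOLD, State.COMPUTE, State.FREE]
assert master_param.trace == [State.FREE, State.COMPUTE, State.HOLD]
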
@@ -10,7 +10,6 @@ class ShardedParamV2(object):
 
     def __init__(self, param: torch.nn.Parameter, rm_torch_payload=False) -> None:
         self._sharded_data_tensor: ShardedTensor = ShardedTensor(param.data)
-        self.fp16_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         self.saved_grad: StatefulTensor = StatefulTensor(None, TensorState.FREE)
         # This attribute must be initialized in ShardedModel
         self.offload_grad: bool = False

@@ -57,10 +56,6 @@ class ShardedParamV2(object):
         _update_mem_use(self.sharded_data_tensor.payload)
         address_set.add(self.sharded_data_tensor.payload.data_ptr())
 
-        if not self.fp16_grad.is_null() and self.fp16_grad.data_ptr() not in address_set:
-            _update_mem_use(self.fp16_grad.payload)
-            address_set.add(self.fp16_grad.data_ptr())
-
         if not self.saved_grad.is_null() and self.saved_grad.data_ptr() not in address_set:
             _update_mem_use(self.saved_grad.payload)
             address_set.add(self.saved_grad.data_ptr())

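The accounting above sums each payload's bytes once, using an address set keyed by `data_ptr()` so tensors that share storage are not double-counted. A self-contained sketch of that dedup logic over plain tensors; the helper name `tensor_mem_usage` is ours.

import torch

def tensor_mem_usage(tensors):
    # Sum CUDA and CPU bytes, counting each underlying storage only once (dedup by data_ptr).
    cuda_use, cpu_use, seen = 0, 0, set()
    for t in tensors:
        if t is None or t.data_ptr() in seen:
            continue
        seen.add(t.data_ptr())
        nbytes = t.numel() * t.element_size()
        if t.device.type == 'cuda':
            cuda_use += nbytes
        else:
            cpu_use += nbytes
    return cuda_use, cpu_use

payload = torch.randn(2, 3)
alias = payload.view(6)                     # shares storage with `payload`
cuda_use, cpu_use = tensor_mem_usage([payload, alias, torch.randn(2, 3)])
assert cuda_use == 0 and cpu_use == 2 * (2 * 3 * 4)   # two distinct fp32 2x3 tensors
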
@@ -63,12 +63,6 @@ def _run_shard_param_v2(rank, world_size, port):
     # 4 is size of dummy tensor of param.data
     assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
 
-    sparam.fp16_grad = StatefulTensor(torch.randn(2, 3).cuda().half())
-    cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()
-    assert cpu_mem_use == 2 * 3 * 4 * 2 + 4
-    assert cuda_mem_use == 2 * 3 * 2
-
-    sparam.fp16_grad = StatefulTensor(None)
     sparam.saved_grad = StatefulTensor(torch.randn(2, 3))
     sparam.remove_torch_payload()
     cuda_mem_use, cpu_mem_use = sparam.get_memory_usage()

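For reference, one plausible reading of the `2 * 3 * 4 * 2 + 4` byte count in the surviving assertion: two 2x3 fp32 CPU tensors (the sharded param payload and `saved_grad`) plus the 4-byte dummy tensor left in `param.data`, as the test's own comment notes. A quick arithmetic check:

fp32_tensor_2x3 = 2 * 3 * 4            # 6 elements * 4 bytes each
two_cpu_tensors = 2 * fp32_tensor_2x3  # sharded param payload + saved_grad
dummy_param_data = 4                   # "4 is size of dummy tensor of param.data"
assert two_cpu_tensors + dummy_param_data == 2 * 3 * 4 * 2 + 4 == 52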