@@ -12,11 +12,12 @@ from colossalai.logging import get_dist_logger
 from colossalai.nn.optimizer import ColossalaiOptimizer
 from colossalai.utils.memory_tracer.model_data_memtracer import \
     GLOBAL_MODEL_DATA_TRACER
-from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move, colo_model_tensor_clone,
+from colossalai.utils.memory_utils.utils import (colo_model_data_tensor_move_inline, colo_model_tensor_clone,
                                                   colo_tensor_mem_usage)
 from colossalai.zero.sharded_model import ShardedModelV2
 from colossalai.zero.sharded_model._utils import cast_tensor_to_fp32
 from colossalai.zero.sharded_optim._utils import has_inf_or_nan
+from colossalai.zero.sharded_param.tensorful_state import (StatefulTensor, TensorState)
 from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
@@ -112,7 +113,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._logger = get_dist_logger("ShardedOptimizerV2")

         # Store fp32 param shards
-        self.master_params: Dict[Parameter, Tensor] = {}
+        self.master_params: Dict[Parameter, StatefulTensor] = {}

         for group in self.optim.param_groups:
             for p in group['params']:
@@ -123,7 +124,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                     # Param is not sharded, which means we use ZeRO-2 here
                     # As we only store param shard, we shard it here
                     self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                self.master_params[p] = cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device)
+                self.master_params[p] = StatefulTensor(
+                    cast_tensor_to_fp32(p.colo_attr.sharded_data_tensor.payload).to(self.device))
                 if not is_param_sharded:
                     # In this branch, there's no need to shard param
                     # So we gather here
@@ -184,13 +186,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             self.zero_grad()
             return

-        # assign master param pointers to p.data.
-        # We will not trigger data copy here.
-        for group in self.optim.param_groups:
-            for p in group['params']:
-                p.data = self.master_params[p]
-                # Now p.data is sharded
-                # So optimizer states are sharded naturally
+        self._prepare_data()

         self._logger.debug(
             f"Before step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
@@ -201,30 +197,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         self._logger.debug(
             f"After step ShardedOptimizerV2 consumes {self.get_memory_usage()[0]/1e6} MB CUDA Memory, {self.get_memory_usage()[1]/1e6} MB CUDA Memory!",
             ranks=[0])
-        # Copy master param data (fp32) to payload of colo_attr (fp16)
-        # TODO() improve efficiency by gathering tensors into a chunk and transferring
-        # a chunk.
-        for group in self.optim.param_groups:
-            for p in group['params']:
-                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
-                if not is_param_sharded:
-                    # We use ZeRO-2 here
-                    # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
-                    # But we only have updated fp32 param shard here
-                    # So we first shard full fp16 param and copy fp32 param shard to it
-                    # Then we will gather them
-                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                # We have to use `copy_payload` instead of `reset_payload`
-                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
-
-                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
-                p.colo_attr.sharded_data_tensor.reset_payload(
-                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
-
-                if not is_param_sharded:
-                    # We gather full fp16 param here
-                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
-                p.data = p.colo_attr.sharded_data_tensor.payload
+        self._write_back_data()
         return ret

     def backward(self, loss: Tensor) -> None:
@@ -276,6 +249,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Because we will judge whether local grad accumulation
         # is enabled by whether grad is None
         self.optim.zero_grad(set_to_none=True)
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                p.colo_attr.saved_grad.set_null()

     def sync_grad(self):
         pass
@@ -288,9 +264,9 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
             fp32_shards_used_cuda_margin_mem = 0
             for group in self.optim.param_groups:
                 for p in group['params']:
-                    shard_mem = self.master_params[p].numel() * self.master_params[p].element_size()
+                    shard_mem = self.master_params[p].payload.numel() * self.master_params[p].payload.element_size()
                     if fp32_shards_used_cuda_margin_mem + shard_mem < fp32_shards_available_cuda_margin_mem:
-                        self.master_params[p] = self.master_params[p].to(torch.cuda.current_device())
+                        colo_model_data_tensor_move_inline(self.master_params[p], torch.cuda.current_device())
                         p.grad.data = p.grad.data.to(torch.cuda.current_device())
                         p.colo_attr.offload_grad = False
                         fp32_shards_used_cuda_margin_mem += shard_mem
@@ -298,10 +274,50 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
     def _prepare_grads(self):
         for group in self.optim.param_groups:
             for p in group['params']:
+                p.colo_attr.saved_grad.trans_state(TensorState.COMPUTE)
                 # FIXME(ver217): p.data here is an empty tensor on CUDA and has no useful information
                 # If we change p.grad directly
                 # it may raise error because of different shape/dtype/device of p.data and p.grad
                 # We just set p.data = p.colo_attr.saved_grad.payload here
                 p.data = p.colo_attr.saved_grad.payload
+                p.grad = p.colo_attr.saved_grad.payload
                 # Set p.data to empty tensor, in case of memory leaking
                 p.colo_attr.remove_torch_payload()
+
+    def _prepare_data(self):
+        # assign master param pointers to p.data.
+        # We will not trigger data copy here.
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                self.master_params[p].trans_state(TensorState.COMPUTE)
+                p.data = self.master_params[p].payload
+                # Now p.data is sharded
+                # So optimizer states are sharded naturally
+
+    def _write_back_data(self):
+        # Copy master param data (fp32) to payload of colo_attr (fp16)
+        # TODO() improve efficiency by gathering tensors into a chunk and transferring
+        # a chunk.
+        for group in self.optim.param_groups:
+            for p in group['params']:
+                is_param_sharded = p.colo_attr.sharded_data_tensor.is_sharded
+                if not is_param_sharded:
+                    # We use ZeRO-2 here
+                    # The `p.colo_attr.sharded_data_tensor` saves full fp16 param
+                    # But we only have updated fp32 param shard here
+                    # So we first shard full fp16 param and copy fp32 param shard to it
+                    # Then we will gather them
+                    self.shard_strategy.shard([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                # We have to use `copy_payload` instead of `reset_payload`
+                # Since p.data is fp32 and p.colo_attr.sharded_data_tensor is fp16
+
+                # TODO() optimize this line CPU (fp32) -> GPU (fp16)
+                p.colo_attr.sharded_data_tensor.reset_payload(
+                    colo_model_tensor_clone(p.half(), torch.cuda.current_device()))
+
+                if not is_param_sharded:
+                    # We gather full fp16 param here
+                    self.shard_strategy.gather([p.colo_attr.sharded_data_tensor], self.dp_process_group)
+                p.data = p.colo_attr.sharded_data_tensor.payload
+                self.master_params[p].trans_state(TensorState.HOLD)
+                p.colo_attr.saved_grad.set_null()
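
A note for readers of the new `StatefulTensor`-based `master_params`: the sketch below is a minimal, self-contained illustration (not the ColossalAI implementation) of the pattern this diff adopts. Each fp32 master shard is wrapped in an object that owns a `payload` tensor plus a lifecycle state, is switched to COMPUTE around the inner optimizer update (`_prepare_data`), and is written back to the fp16 shard and returned to HOLD afterwards (`_write_back_data`). The class and enum here are simplified stand-ins for the real ones in `colossalai.zero.sharded_param.tensorful_state`.

# Minimal sketch of the StatefulTensor pattern used by master_params above.
# Illustration only; the real class has a richer API (state checks, device moves, etc.).
from enum import Enum

import torch


class TensorState(Enum):
    HOLD = 0
    COMPUTE = 1


class StatefulTensor:
    """A tensor wrapper that tracks a lifecycle state alongside its payload."""

    def __init__(self, tensor: torch.Tensor, state: TensorState = TensorState.HOLD) -> None:
        self._payload = tensor
        self._state = state

    @property
    def payload(self) -> torch.Tensor:
        return self._payload

    def trans_state(self, state: TensorState) -> None:
        self._state = state


if __name__ == '__main__':
    # fp16 shard as the sharded model would hold it
    fp16_shard = torch.randn(4).half()
    # fp32 master copy kept by the optimizer, wrapped so its state can be tracked
    master = StatefulTensor(fp16_shard.float())

    master.trans_state(TensorState.COMPUTE)    # _prepare_data(): expose the payload to the optimizer
    master.payload.add_(1.0)                   # stand-in for the wrapped optimizer's update
    fp16_shard.copy_(master.payload.half())    # _write_back_data(): cast fp32 -> fp16 and copy back
    master.trans_state(TensorState.HOLD)       # shard is idle again until the next step
    print(fp16_shard)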