perf: use async copy to accelerate memcpy

pull/5817/head
Wenhao Chen 2024-03-28 15:02:32 +08:00 committed by アマデウス
parent a53c8c1ade
commit 1aaa453706
1 changed file with 66 additions and 22 deletions

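The change overlaps host/device memcpy with optimizer compute: gradients and updated master parameters are moved with `non_blocking=True` copies and consumed through a `DataPrefetcher` that stays one item ahead of the per-parameter `optim.step()` loop. The `DataPrefetcher` itself is imported from `._utils` and its implementation is not part of this diff; the sketch below is only an illustration of how such a one-item-lookahead prefetcher is commonly written (copies issued on a side CUDA stream), not the code from this commit.

# Illustrative sketch only: the real DataPrefetcher lives in colossalai's _utils module and is
# not shown in this diff. The stream handling here is an assumption about a typical prefetcher.
import torch


class DataPrefetcher:
    """One-item lookahead over a generator of tensors.

    Copies requested by the generator (e.g. ``.to(device, non_blocking=True)``) are enqueued
    on a side CUDA stream so they can overlap with compute on the default stream.
    """

    def __init__(self, loader):
        self.loader = loader
        self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.next_item = None
        self._preload()

    def _preload(self):
        try:
            if self.stream is not None:
                with torch.cuda.stream(self.stream):
                    # the generator's .to(..., non_blocking=True) copies are enqueued here,
                    # on the side stream, instead of blocking the default stream
                    self.next_item = next(self.loader)
            else:
                self.next_item = next(self.loader)
        except StopIteration:
            self.next_item = None

    def next(self):
        if self.stream is not None:
            # make the consumer's stream wait for the in-flight copy before using the tensor
            torch.cuda.current_stream().wait_stream(self.stream)
        item = self.next_item
        self._preload()
        return item

Note that asynchronous copies between host and device only overlap with compute when the host-side tensor is in pinned memory; otherwise `.to(..., non_blocking=True)` degrades to an effectively synchronous copy.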

@@ -21,7 +21,14 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
from colossalai.interface import OptimizerWrapper
from colossalai.logging import get_dist_logger
from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
from ._utils import (
    DataPrefetcher,
    calculate_global_norm_from_list,
    flatten,
    has_inf_or_nan,
    release_param_grad,
    sync_tensor,
)
from .bookkeeping import BucketStore, GradientStore, ParameterStore
@@ -437,10 +444,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                if len(grads) > 0:
                    real_working_params[group_id].append(working_param)
                    grad = grads[grad_index]
                    # no need to copy fp32 grad if master_weights is False
                    if self._master_weights:
                        grad = grad.to(splited_param.dtype).to(splited_param.device)
                    splited_param.grad = grad
                    grad_partition_groups.append(grad)
                    real_master_params[group_id].append(splited_param)
@@ -458,27 +461,68 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
        global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
        self._unscale_and_clip_grads(grad_partition_groups, global_norm)
        # update the parameters
        self.optim.step()
        # release the grad
        grad_partition_groups = []
        for group_id in range(self.num_param_groups):
            release_param_grad(self._master_param_groups_of_current_rank[group_id])
        # update working partition updated by the current rank
        device = get_accelerator().get_current_device()
        for group_id in range(self.num_param_groups):
            master_working_param = self.optim.param_groups[group_id]["params"]
            for idx, splited_param in enumerate(master_working_param):
                working_param = real_working_params[group_id][idx]
                all_splited_param = [
                    torch.zeros(splited_param.shape, device=device, dtype=self._dtype) for _ in range(self._world_size)
                ]
                dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
                working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
            def load_grad(num: int):
                """copy grads to the same device and dtype as the master weights"""
                for i in range(num):
                    grad = grad_partition_groups.pop(0)
                    # no need to copy fp32 grad if master_weights is False
                    if self._master_weights:
                        grad = grad.to(real_master_params[group_id][i].dtype).to(
                            real_master_params[group_id][i].device, non_blocking=True
                        )
                    yield grad

            def load_param(num: int):
                """copy params back to the accelerator"""
                for i in range(num):
                    splited_param = real_master_params[group_id][i].to(device, non_blocking=True).to(self._dtype)
                    yield splited_param

            """
            grad (device) --> grad (host or device) --> optim.step() --> param (host or device) --> param (device)
            """
            grad_pre_fetcher, param_pre_fetcher = None, None
            for idx in range(len(real_master_params[group_id]) + 1):
                is_first_step = idx == 0
                is_last_step = idx == len(real_master_params[group_id])
                if not is_last_step:
                    # update the parameters
                    if grad_pre_fetcher is None:
                        grad_pre_fetcher = DataPrefetcher(load_grad(len(real_master_params[group_id])))
                    real_master_params[group_id][idx].grad = grad_pre_fetcher.next()
                    # HACK: torch optim would skip tensor whose grad is None
                    self.optim.step()
                    real_master_params[group_id][idx].grad = None
                if not is_first_step:
                    # update working partition updated by the current rank
                    if param_pre_fetcher is None:
                        param_pre_fetcher = DataPrefetcher(load_param(len(real_master_params[group_id])))
                    working_param = real_working_params[group_id][idx - 1]
                    splited_param = param_pre_fetcher.next()
                    all_splited_param = [
                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
                        for _ in range(self._world_size)
                    ]
                    dist.all_gather(all_splited_param, splited_param, group=self.dp_pg)
                    working_param.data.copy_(
                        flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)
                    )

            # release the grad
            release_param_grad(self._master_param_groups_of_current_rank[group_id])
            self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]

        assert len(grad_partition_groups) == 0, "grad_partition_groups should be empty after step()"

    def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
        r"""
        Compute and return the gradient norm for gradient clipping.