mirror of https://github.com/hpcaitech/ColossalAI
perf: use async copy to accelerate memcpy
parent a53c8c1ade
commit 1aaa453706
@@ -21,7 +21,14 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
 
-from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
+from ._utils import (
+    DataPrefetcher,
+    calculate_global_norm_from_list,
+    flatten,
+    has_inf_or_nan,
+    release_param_grad,
+    sync_tensor,
+)
 from .bookkeeping import BucketStore, GradientStore, ParameterStore
 
 
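The hunk above imports DataPrefetcher from ._utils, but its implementation is not part of this diff. As a point of reference only, a minimal one-element-lookahead prefetcher over a generator could look like the sketch below; the interface (constructor takes a generator, next() returns the next item) matches how the diff uses it, while the body is an assumption, not the actual ColossalAI helper.

# Hypothetical sketch of a generator prefetcher; the real DataPrefetcher in
# colossalai/zero/low_level/_utils.py may be implemented differently.
from typing import Iterator, Optional

import torch


class DataPrefetcher:
    """Keep one item of look-ahead so the copy for item i+1 (issued inside the
    generator with non_blocking=True) can overlap with the work done on item i."""

    def __init__(self, loader: Iterator[torch.Tensor]):
        self.loader = loader
        # use a side stream so queued copies do not serialize with compute
        self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self._next: Optional[torch.Tensor] = None
        self._preload()

    def _preload(self) -> None:
        try:
            if self.stream is not None:
                # run the generator body (which issues the non_blocking copy)
                # on the side stream
                with torch.cuda.stream(self.stream):
                    self._next = next(self.loader)
            else:
                self._next = next(self.loader)
        except StopIteration:
            self._next = None

    def next(self) -> Optional[torch.Tensor]:
        if self.stream is not None:
            # block the current stream until the prefetched copy has landed
            torch.cuda.current_stream().wait_stream(self.stream)
        item = self._next
        self._preload()
        return item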
@@ -437,10 +444,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 if len(grads) > 0:
                     real_working_params[group_id].append(working_param)
                     grad = grads[grad_index]
-                    # no need to copy fp32 grad if master_weights is False
-                    if self._master_weights:
-                        grad = grad.to(splited_param.dtype).to(splited_param.device)
-                    splited_param.grad = grad
                     grad_partition_groups.append(grad)
                     real_master_params[group_id].append(splited_param)
 
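The synchronous grad.to(dtype).to(device) conversion removed above reappears later inside load_grad with non_blocking=True, which is what the commit title's "async copy" refers to. One general CUDA caveat, illustrated below on made-up tensors rather than ColossalAI code: a non_blocking copy between host and device only truly overlaps with compute when the CPU-side tensor lives in pinned (page-locked) memory; otherwise it effectively degrades to a blocking transfer.

# Standalone illustration of asynchronous host<->device copies; the tensor
# names and shapes are made up and not part of the diff.
import torch

if torch.cuda.is_available():
    # pinned CPU buffer: a prerequisite for the D2H copy below to actually
    # run asynchronously
    host_grad = torch.empty(1024, 1024, dtype=torch.float32, pin_memory=True)
    device_grad = torch.randn(1024, 1024, dtype=torch.float16, device="cuda")

    # cast on the GPU, then queue the device -> host copy without blocking
    host_grad.copy_(device_grad.to(torch.float32), non_blocking=True)

    # pinned source, so the host -> device copy is queued asynchronously too
    device_param = host_grad.to("cuda", non_blocking=True)

    torch.cuda.synchronize()  # wait for the queued copies before using them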
@@ -458,27 +461,68 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
         self._unscale_and_clip_grads(grad_partition_groups, global_norm)
 
-        # update the parameters
-        self.optim.step()
-
-        # release the grad
-        grad_partition_groups = []
-        for group_id in range(self.num_param_groups):
-            release_param_grad(self._master_param_groups_of_current_rank[group_id])
-
-        # update working partition updated by the current rank
         device = get_accelerator().get_current_device()
         for group_id in range(self.num_param_groups):
-            master_working_param = self.optim.param_groups[group_id]["params"]
-            for idx, splited_param in enumerate(master_working_param):
-                working_param = real_working_params[group_id][idx]
-                all_splited_param = [
-                    torch.zeros(splited_param.shape, device=device, dtype=self._dtype) for _ in range(self._world_size)
-                ]
-                dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
-                working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
+
+            def load_grad(num: int):
+                """copy grads to the same device and dtype as the master weights"""
+                for i in range(num):
+                    grad = grad_partition_groups.pop(0)
+                    # no need to copy fp32 grad if master_weights is False
+                    if self._master_weights:
+                        grad = grad.to(real_master_params[group_id][i].dtype).to(
+                            real_master_params[group_id][i].device, non_blocking=True
+                        )
+                    yield grad
+
+            def load_param(num: int):
+                """copy params back to the accelerator"""
+                for i in range(num):
+                    splited_param = real_master_params[group_id][i].to(device, non_blocking=True).to(self._dtype)
+                    yield splited_param
+
+            """
+            grad (device) --> grad (host or device) --> optim.step() --> param (host or device) --> param (device)
+            """
+            grad_pre_fetcher, param_pre_fetcher = None, None
+            for idx in range(len(real_master_params[group_id]) + 1):
+                is_first_step = idx == 0
+                is_last_step = idx == len(real_master_params[group_id])
+
+                if not is_last_step:
+                    # update the parameters
+                    if grad_pre_fetcher is None:
+                        grad_pre_fetcher = DataPrefetcher(load_grad(len(real_master_params[group_id])))
+
+                    real_master_params[group_id][idx].grad = grad_pre_fetcher.next()
+                    # HACK: torch optim would skip tensor whose grad is None
+                    self.optim.step()
+                    real_master_params[group_id][idx].grad = None
+
+                if not is_first_step:
+                    # update working partition updated by the current rank
+                    if param_pre_fetcher is None:
+                        param_pre_fetcher = DataPrefetcher(load_param(len(real_master_params[group_id])))
+
+                    working_param = real_working_params[group_id][idx - 1]
+                    splited_param = param_pre_fetcher.next()
+
+                    all_splited_param = [
+                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
+                        for _ in range(self._world_size)
+                    ]
+                    dist.all_gather(all_splited_param, splited_param, group=self.dp_pg)
+
+                    working_param.data.copy_(
+                        flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)
+                    )
+
+            # release the grad
+            release_param_grad(self._master_param_groups_of_current_rank[group_id])
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
 
+        assert len(grad_partition_groups) == 0, "grad_partition_groups should be empty after step()"
+
     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""
         Compute and return the gradient norm for gradient clipping.
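Taken together, the rewritten step() runs a one-stage software pipeline per parameter group: the grad for master parameter idx is prefetched while the previous parameter is being all-gathered, and the copy-back of parameter idx overlaps with the optimizer step for idx + 1. The toy sketch below reproduces only that interleaving skeleton; the work functions are placeholders, not ColossalAI APIs.

# Toy skeleton of the interleaving used in step(); all names are illustrative.
from typing import Callable, Iterator, List


def pipelined_step(
    items: List[int],
    load_grad: Callable[[int], str],      # stands in for the async grad copy
    step_one: Callable[[str], str],       # stands in for optim.step() on one param
    publish: Callable[[int, str], None],  # stands in for all_gather + copy-back
) -> None:
    def grads() -> Iterator[str]:
        for i in items:
            yield load_grad(i)  # in the real code this generator is wrapped in DataPrefetcher

    updated = {}
    grad_iter = grads()
    for idx in range(len(items) + 1):
        is_first_step = idx == 0
        is_last_step = idx == len(items)
        if not is_last_step:
            # consume the (conceptually prefetched) grad and update one param
            updated[idx] = step_one(next(grad_iter))
        if not is_first_step:
            # publish the param finished on the previous iteration; with async
            # copies this overlaps with the update issued just above
            publish(idx - 1, updated.pop(idx - 1))


if __name__ == "__main__":
    published = []
    pipelined_step(
        items=[0, 1, 2],
        load_grad=lambda i: f"grad{i}",
        step_one=lambda g: g.replace("grad", "param"),
        publish=lambda i, p: published.append((i, p)),
    )
    print(published)  # [(0, 'param0'), (1, 'param1'), (2, 'param2')]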