mirror of https://github.com/hpcaitech/ColossalAI
perf: use async copy to accelerate memcpy
parent a53c8c1ade
commit 1aaa453706
@@ -21,7 +21,14 @@ from colossalai.amp.naive_amp.mixed_precision_mixin import (
 from colossalai.interface import OptimizerWrapper
 from colossalai.logging import get_dist_logger
 
-from ._utils import calculate_global_norm_from_list, flatten, has_inf_or_nan, release_param_grad, sync_tensor
+from ._utils import (
+    DataPrefetcher,
+    calculate_global_norm_from_list,
+    flatten,
+    has_inf_or_nan,
+    release_param_grad,
+    sync_tensor,
+)
 from .bookkeeping import BucketStore, GradientStore, ParameterStore
 
 
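The import hunk above pulls in DataPrefetcher from ._utils; its implementation is not part of this diff. Below is a minimal, hypothetical sketch of a generator-backed prefetcher exposing the same .next() interface, assuming it keeps one item in flight on a side CUDA stream so the copy of the next tensor can overlap with the compute that consumes the current one (the name PrefetcherSketch and the stream handling are assumptions, not the ColossalAI implementation):

# Hypothetical sketch; the real DataPrefetcher imported from ._utils above may differ.
from typing import Generator, Optional

import torch


class PrefetcherSketch:
    """Keep one item in flight so the next copy overlaps with current compute."""

    def __init__(self, loader: Generator):
        self.loader = loader
        # a side CUDA stream lets H2D/D2H copies run concurrently with the
        # default stream; fall back to synchronous prefetch on CPU-only runs
        self.stream = torch.cuda.Stream() if torch.cuda.is_available() else None
        self.next_item: Optional[torch.Tensor] = None
        self._preload()

    def _preload(self) -> None:
        try:
            if self.stream is not None:
                with torch.cuda.stream(self.stream):
                    self.next_item = next(self.loader)
            else:
                self.next_item = next(self.loader)
        except StopIteration:
            self.next_item = None

    def next(self) -> Optional[torch.Tensor]:
        # make the consumer stream wait for the in-flight copy, then kick off
        # the following one before returning the ready tensor
        if self.stream is not None:
            torch.cuda.current_stream().wait_stream(self.stream)
        item = self.next_item
        self._preload()
        return item

In this sketch, issuing the preload on a separate stream is what would allow the non_blocking copies yielded by the generators in the later hunks to run concurrently with the optimizer work on the default stream.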
@@ -437,10 +444,6 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
                 if len(grads) > 0:
                     real_working_params[group_id].append(working_param)
                     grad = grads[grad_index]
-                    # no need to copy fp32 grad if master_weights is False
-                    if self._master_weights:
-                        grad = grad.to(splited_param.dtype).to(splited_param.device)
-                    splited_param.grad = grad
                     grad_partition_groups.append(grad)
                     real_master_params[group_id].append(splited_param)
 
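This hunk drops the eager grad.to(splited_param.dtype).to(splited_param.device) copy performed while the gradients are being collected; the same conversion reappears in the next hunk inside the load_grad generator, issued with non_blocking=True so the memcpy can overlap with the optimizer step. As a side note on the technique named in the commit title, a non-blocking transfer only runs asynchronously with respect to the host when the CPU-side tensor is in pinned (page-locked) memory; the snippet below is an illustration of that behaviour, not part of the diff:

import torch

if torch.cuda.is_available():
    device = torch.device("cuda")

    pageable = torch.randn(1 << 20)             # regular, pageable host tensor
    pinned = torch.randn(1 << 20).pin_memory()  # page-locked host tensor

    a = pageable.to(device, non_blocking=True)  # non_blocking has no effect here
    b = pinned.to(device, non_blocking=True)    # returns immediately; copy runs async

    torch.cuda.synchronize()                    # wait for the async copy before using b

If the master weights live on the CPU, the copies therefore only overlap when the host tensors involved are pinned; otherwise non_blocking=True falls back to an ordinary blocking copy.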
@@ -458,27 +461,68 @@ class LowLevelZeroOptimizer(OptimizerWrapper):
         global_norm = calculate_global_norm_from_list(norm_list=norm_groups)
         self._unscale_and_clip_grads(grad_partition_groups, global_norm)
 
-        # update the parameters
-        self.optim.step()
-
-        # release the grad
-        grad_partition_groups = []
-        for group_id in range(self.num_param_groups):
-            release_param_grad(self._master_param_groups_of_current_rank[group_id])
-
-        # update working partition updated by the current rank
         device = get_accelerator().get_current_device()
         for group_id in range(self.num_param_groups):
-            master_working_param = self.optim.param_groups[group_id]["params"]
-            for idx, splited_param in enumerate(master_working_param):
-                working_param = real_working_params[group_id][idx]
-                all_splited_param = [
-                    torch.zeros(splited_param.shape, device=device, dtype=self._dtype) for _ in range(self._world_size)
-                ]
-                dist.all_gather(all_splited_param, splited_param.to(device).to(self._dtype), group=self.dp_pg)
-                working_param.data.copy_(flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param))
+
+            def load_grad(num: int):
+                """copy grads to the same device and dtype as the master weights"""
+                for i in range(num):
+                    grad = grad_partition_groups.pop(0)
+                    # no need to copy fp32 grad if master_weights is False
+                    if self._master_weights:
+                        grad = grad.to(real_master_params[group_id][i].dtype).to(
+                            real_master_params[group_id][i].device, non_blocking=True
+                        )
+                    yield grad
+
+            def load_param(num: int):
+                """copy params back to the accelerator"""
+                for i in range(num):
+                    splited_param = real_master_params[group_id][i].to(device, non_blocking=True).to(self._dtype)
+                    yield splited_param
+
+            """
+            grad (device) --> grad (host or device) --> optim.step() --> param (host or device) --> param (device)
+            """
+            grad_pre_fetcher, param_pre_fetcher = None, None
+            for idx in range(len(real_master_params[group_id]) + 1):
+                is_first_step = idx == 0
+                is_last_step = idx == len(real_master_params[group_id])
+
+                if not is_last_step:
+                    # update the parameters
+                    if grad_pre_fetcher is None:
+                        grad_pre_fetcher = DataPrefetcher(load_grad(len(real_master_params[group_id])))
+
+                    real_master_params[group_id][idx].grad = grad_pre_fetcher.next()
+                    # HACK: torch optim would skip tensor whose grad is None
+                    self.optim.step()
+                    real_master_params[group_id][idx].grad = None
+
+                if not is_first_step:
+                    # update working partition updated by the current rank
+                    if param_pre_fetcher is None:
+                        param_pre_fetcher = DataPrefetcher(load_param(len(real_master_params[group_id])))
+
+                    working_param = real_working_params[group_id][idx - 1]
+                    splited_param = param_pre_fetcher.next()
+
+                    all_splited_param = [
+                        torch.zeros(splited_param.shape, device=device, dtype=self._dtype)
+                        for _ in range(self._world_size)
+                    ]
+                    dist.all_gather(all_splited_param, splited_param, group=self.dp_pg)
+
+                    working_param.data.copy_(
+                        flatten(all_splited_param)[: working_param.numel()].reshape_as(working_param)
+                    )
+
+            # release the grad
+            release_param_grad(self._master_param_groups_of_current_rank[group_id])
             self.optim.param_groups[group_id]["params"] = self._master_param_groups_of_current_rank[group_id]
 
+        assert len(grad_partition_groups) == 0, "grad_partition_groups should be empty after step()"
+
     def _compute_grad_norm(self, gradients: List[Tensor], norm_type: int = 2) -> float:
         r"""
         Compute and return the gradient norm for gradient clipping.
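Taken together, the rewritten loop software-pipelines the work per parameter, following the "grad (device) --> grad (host or device) --> optim.step() --> param (host or device) --> param (device)" flow annotated in the diff: while parameter idx is being stepped, the gradient for the next parameter is already being copied in, and the freshly updated previous parameter is being copied back and all-gathered. The sketch below reproduces that interleaving on toy tensors, without the ZeRO bookkeeping or the collective; every name in it (demo_pipelined_step, prefetcher_cls, optim_step) is hypothetical:

import torch

def demo_pipelined_step(params, grads, optim_step, prefetcher_cls):
    """params: CPU master params; grads: device grads; both lists of equal length."""

    device = "cuda" if torch.cuda.is_available() else "cpu"

    def load_grad():
        for p, g in zip(params, grads):
            # bring each device grad next to its master weight without blocking
            yield g.to(p.device, dtype=p.dtype, non_blocking=True)

    def load_param():
        for p in params:
            # push each updated master weight back to the accelerator
            yield p.to(device, non_blocking=True)

    grad_fetcher = prefetcher_cls(load_grad())
    param_fetcher = None
    updated = []
    n = len(params)
    for idx in range(n + 1):
        if idx < n:  # step parameter idx with its freshly prefetched grad
            params[idx].grad = grad_fetcher.next()
            optim_step(params[idx])
            params[idx].grad = None
        if idx > 0:  # copy back parameter idx - 1, updated in the previous iteration
            if param_fetcher is None:
                param_fetcher = prefetcher_cls(load_param())
            updated.append(param_fetcher.next())
    return updated

The per-tensor self.optim.step() call in the diff works because torch optimizers skip any parameter whose .grad is None, which is what the HACK comment refers to: in each iteration only the parameter whose gradient was just attached actually gets updated.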