Merge pull request #2 from SolenoidWGT/fp32_zero

feat(optim): add support for fp32 zero
pull/155/head
ytxiong 2023-08-03 14:58:12 +08:00 committed by GitHub
commit 53fc50b0e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 113 additions and 78 deletions

View File

@ -87,6 +87,7 @@ class HybridZeroOptimizer(BaseOptimizer):
overlap_broadcast=False,
grad_scal_cfg: Config = None,
zero_cfg: Config = None,
use_fp16: bool = True,
):
# DynamicGradScaler related args
initial_scale = grad_scal_cfg.fp16.initial_scale
@ -104,6 +105,7 @@ class HybridZeroOptimizer(BaseOptimizer):
super().__init__(optim=optimizer)
self.use_fp16 = use_fp16
self._dtype = self.optim.param_groups[0]["params"][0].dtype
self._cpu_offload = cpu_offload
self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
@ -125,7 +127,8 @@ class HybridZeroOptimizer(BaseOptimizer):
self._reduce_bucket_size = reduce_bucket_size
# gradient scaler
self.grad_scaler = DynamicGradScaler(
self.grad_scaler = (
DynamicGradScaler(
initial_scale=initial_scale,
min_scale=min_scale,
growth_factor=growth_factor,
@ -134,6 +137,9 @@ class HybridZeroOptimizer(BaseOptimizer):
hysteresis=hysteresis,
max_scale=max_scale,
)
if self.use_fp16
else None
)
self._found_overflow = torch.cuda.FloatTensor([0], device=get_current_device())
# gradient clipping
@ -176,11 +182,14 @@ class HybridZeroOptimizer(BaseOptimizer):
for param in params:
self._param_store.set_param_to_rank(param, rank)
# flatten the reordered tensors
# move to cpu to make room to create the flat tensor
# Even for fp32 training, we will still flattend the tensor,
# which will not increase the use of GPU memory,
# and can improve the efficiency of broadcasting.
for param in group_params:
param.data = param.data.cpu()
# flatten the reordered tensors
for rank in range(self._zero_world_size):
# No flat fp16 buffer is allocated if the process has no parameters.
if rank not in self.param_group_no_params_ranks[group_id]:
@ -194,6 +203,7 @@ class HybridZeroOptimizer(BaseOptimizer):
# create a copy of fp32 weights of the parameters for which this rank is responsible
# No flat fp32 buffer is allocated if the process has no parameters.
if self.param_group_has_params[group_id]:
if self.use_fp16:
fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
self._zero_local_rank, group_id
)
@ -207,6 +217,11 @@ class HybridZeroOptimizer(BaseOptimizer):
# so that when the optimizer calls step(), it only updates the tensors
# managed by this data parallel rank
param_group["params"] = [fp32_flat_current_rank]
else:
# use fp32
param_group["params"] = self._param_store.get_fp16_params_by_rank_group(
self._zero_local_rank, group_id
)
# set reduction state
for param in self._fp16_param_groups[group_id]:
@ -243,6 +258,9 @@ class HybridZeroOptimizer(BaseOptimizer):
@property
def loss_scale(self):
if self.grad_scaler is None:
return 1
else:
return self.grad_scaler.scale
@property
@ -533,6 +551,7 @@ class HybridZeroOptimizer(BaseOptimizer):
norm_groups.append(norm_group)
loss_scale = float(self.loss_scale.item()) # backup
if self.grad_scaler:
self.grad_scaler.update(found_inf)
# update loss scale if overflow occurs
if found_inf:
@ -552,6 +571,7 @@ class HybridZeroOptimizer(BaseOptimizer):
continue
gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
if self.use_fp16:
# create flat gradient for the flat fp32 params
fp16_avg_grads = gradients
flat_fp16_avg_grads = flatten(fp16_avg_grads)
@ -567,6 +587,14 @@ class HybridZeroOptimizer(BaseOptimizer):
single_grad_partition_groups.append(flat_fp32_avg_grads)
device = self._fp32_flat_param_groups_of_current_rank[group_id].device
self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
else:
assert len(gradients) == len(self.optim.param_groups[group_id]["params"]), (
len(gradients),
len(self.optim.param_groups[group_id]["params"]),
)
for g, p in zip(gradients, self.optim.param_groups[group_id]["params"]):
p.grad = g
self._grad_store._averaged_gradients[group_id] = []
self._grad_store._averaged_gradients[group_id] = []
@ -576,6 +604,7 @@ class HybridZeroOptimizer(BaseOptimizer):
global_norm = sum(norm_groups) ** 0.5
# the following operations are performed only on the rank to which parameters are assigned.
if self.use_fp16:
if len(single_grad_partition_groups) != 0:
self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)
@ -588,6 +617,7 @@ class HybridZeroOptimizer(BaseOptimizer):
if self.has_params:
self.optim.step()
# release the fp32 grad
if self.use_fp16:
release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
# update fp16 partition updated by the current rank
for group_id in range(len(self._fp16_param_groups)):
@ -614,8 +644,6 @@ class HybridZeroOptimizer(BaseOptimizer):
# The following operations are performed only on the rank to which parameters are assigned.
if rank not in self.param_group_no_params_ranks[group_id]:
fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
# grank = gpc.get_ranks_in_group(group_type)[rank] # need to convert to the global rank
# assert grank == rank, f"{grank} == {rank}"
g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
handle = dist.broadcast(
fp16_param, src=g_rank, group=gpc.get_group(ParallelMode.ZERO1), async_op=True
@ -667,11 +695,13 @@ class HybridZeroOptimizer(BaseOptimizer):
def state_dict(self):
states = {}
grad_scaler = self.grad_scaler.state_dict()
states["grad_scaler"] = grad_scaler
optim_states = self.optim.state_dict()
states["base_optim_states"] = optim_states
if self.use_fp16:
grad_scaler = self.grad_scaler.state_dict()
states["grad_scaler"] = grad_scaler
flat_fp32_weights = {}
for group_id, param in self._fp32_flat_param_groups_of_current_rank.items():
if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
@ -684,11 +714,13 @@ class HybridZeroOptimizer(BaseOptimizer):
def load_state_dict(self, states):
# TODO: Need to take into account the change in the number of DP.
optim_states = states["base_optim_states"]
self.optim.load_state_dict(optim_states)
if self.use_fp16:
assert "grad_scaler" in states, "Not found grad_scaler state!"
grad_scaler = states["grad_scaler"]
self.grad_scaler.load_state_dict(grad_scaler)
optim_states = states["base_optim_states"]
self.optim.load_state_dict(optim_states)
# load fp32 model weight.
flat_fp32_weights = states["flat_fp32_weights"]

View File

@ -282,7 +282,10 @@ def initialize_optimizer(model: nn.Module):
)
optimizer = HybridZeroOptimizer(
naive_optimizer, grad_scal_cfg=gpc.config.grad_scaler, zero_cfg=gpc.config.hybrid_zero_optimizer
naive_optimizer,
grad_scal_cfg=gpc.config.grad_scaler,
zero_cfg=gpc.config.hybrid_zero_optimizer,
use_fp16= gpc.config.model.dtype is torch.float32,
)
beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)