Merge pull request #2 from SolenoidWGT/fp32_zero

feat(optim): add support for fp32 zero
pull/155/head
ytxiong 2023-08-03 14:58:12 +08:00 committed by GitHub
commit 53fc50b0e5
2 changed files with 113 additions and 78 deletions
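In short: HybridZeroOptimizer gains a use_fp16 flag. When it is False (pure fp32 training), the optimizer skips the dynamic grad scaler, the fp32 master copies of the flat fp16 buffers, and the fp32-to-fp16 copy-back after step(), and instead steps directly on the partitioned fp32 parameters. The call site (second file below) then reduces to the sketch here; naive_optimizer, gpc, and the config objects are assumed to come from the usual InternLM setup, so this fragment is illustrative rather than runnable on its own:

    import torch

    optimizer = HybridZeroOptimizer(
        naive_optimizer,
        grad_scal_cfg=gpc.config.grad_scaler,
        zero_cfg=gpc.config.hybrid_zero_optimizer,
        # fp32 models disable the fp16 machinery
        use_fp16=gpc.config.model.dtype is not torch.float32,
    )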

File 1 of 2: HybridZeroOptimizer

@@ -87,6 +87,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         overlap_broadcast=False,
         grad_scal_cfg: Config = None,
         zero_cfg: Config = None,
+        use_fp16: bool = True,
     ):
         # DynamicGradScaler related args
         initial_scale = grad_scal_cfg.fp16.initial_scale
@@ -104,6 +105,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         super().__init__(optim=optimizer)

+        self.use_fp16 = use_fp16
         self._dtype = self.optim.param_groups[0]["params"][0].dtype
         self._cpu_offload = cpu_offload
         self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
@@ -125,14 +127,18 @@ class HybridZeroOptimizer(BaseOptimizer):
         self._reduce_bucket_size = reduce_bucket_size

         # gradient scaler
-        self.grad_scaler = DynamicGradScaler(
-            initial_scale=initial_scale,
-            min_scale=min_scale,
-            growth_factor=growth_factor,
-            backoff_factor=backoff_factor,
-            growth_interval=growth_interval,
-            hysteresis=hysteresis,
-            max_scale=max_scale,
+        self.grad_scaler = (
+            DynamicGradScaler(
+                initial_scale=initial_scale,
+                min_scale=min_scale,
+                growth_factor=growth_factor,
+                backoff_factor=backoff_factor,
+                growth_interval=growth_interval,
+                hysteresis=hysteresis,
+                max_scale=max_scale,
+            )
+            if self.use_fp16
+            else None
         )

         self._found_overflow = torch.cuda.FloatTensor([0], device=get_current_device())
@@ -176,11 +182,14 @@ class HybridZeroOptimizer(BaseOptimizer):
                 for param in params:
                     self._param_store.set_param_to_rank(param, rank)

-            # flatten the reordered tensors
             # move to cpu to make room to create the flat tensor
+            # Even for fp32 training, we still flatten the tensors,
+            # which does not increase GPU memory usage
+            # and improves the efficiency of broadcasting.
             for param in group_params:
                 param.data = param.data.cpu()

+            # flatten the reordered tensors
             for rank in range(self._zero_world_size):
                 # No flat fp16 buffer is allocated if the process has no parameters.
                 if rank not in self.param_group_no_params_ranks[group_id]:
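The comment above is the rationale for keeping the flattening step in the fp32 path: one contiguous buffer per rank means one broadcast per rank instead of one collective call per parameter. A minimal single-process sketch of that idea, using torch's flatten/unflatten helpers (which the flatten() call later in this diff is presumably a thin wrapper around); the dist.broadcast call is only indicated in a comment so the snippet runs standalone:

    import torch
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

    params = [torch.randn(4, 4), torch.randn(8)]
    flat = _flatten_dense_tensors(params)  # one contiguous 1-D buffer
    # in the real optimizer: dist.broadcast(flat, src=g_rank, group=...)
    for p, synced in zip(params, _unflatten_dense_tensors(flat, params)):
        p.copy_(synced)  # write the (broadcast) values back into each param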
@@ -194,19 +203,25 @@ class HybridZeroOptimizer(BaseOptimizer):
             # create a copy of fp32 weights of the parameters for which this rank is responsible
             # No flat fp32 buffer is allocated if the process has no parameters.
             if self.param_group_has_params[group_id]:
-                fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
-                    self._zero_local_rank, group_id
-                )
-                fp32_flat_current_rank = fp16_flat_current_rank.float()
-                device = "cpu" if self._cpu_offload else get_current_device()
-                fp32_flat_current_rank = fp32_flat_current_rank.to(device)
-                fp32_flat_current_rank.requires_grad = True
-                self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank
+                if self.use_fp16:
+                    fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
+                        self._zero_local_rank, group_id
+                    )
+                    fp32_flat_current_rank = fp16_flat_current_rank.float()
+                    device = "cpu" if self._cpu_offload else get_current_device()
+                    fp32_flat_current_rank = fp32_flat_current_rank.to(device)
+                    fp32_flat_current_rank.requires_grad = True
+                    self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank

-                # need to replace the params in the `params` field in the optimizer
-                # so that when the optimizer calls step(), it only updates the tensors
-                # managed by this data parallel rank
-                param_group["params"] = [fp32_flat_current_rank]
+                    # need to replace the params in the `params` field in the optimizer
+                    # so that when the optimizer calls step(), it only updates the tensors
+                    # managed by this data parallel rank
+                    param_group["params"] = [fp32_flat_current_rank]
+                else:
+                    # use fp32
+                    param_group["params"] = self._param_store.get_fp16_params_by_rank_group(
+                        self._zero_local_rank, group_id
+                    )

             # set reduction state
             for param in self._fp16_param_groups[group_id]:
@@ -243,7 +258,10 @@ class HybridZeroOptimizer(BaseOptimizer):

     @property
     def loss_scale(self):
-        return self.grad_scaler.scale
+        if self.grad_scaler is None:
+            return 1
+        else:
+            return self.grad_scaler.scale

     @property
     def num_param_groups(self):
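Taken together with the constructor change earlier in this diff, the pattern is: no grad scaler object exists in fp32 mode, and every scaler-dependent step degrades to a plain optimizer step with a loss scale of 1. A self-contained toy of that pattern in plain PyTorch; the class and names here are illustrative, not the InternLM implementation:

    import torch

    class ToyZeroStep:
        def __init__(self, params, use_fp16: bool):
            self.use_fp16 = use_fp16
            self.inner = torch.optim.SGD(params, lr=0.1)
            # only mixed-precision training needs a grad scaler
            self.grad_scaler = torch.cuda.amp.GradScaler() if use_fp16 else None

        @property
        def loss_scale(self):
            return 1.0 if self.grad_scaler is None else self.grad_scaler.get_scale()

        def backward_and_step(self, loss):
            if self.grad_scaler is not None:
                self.grad_scaler.scale(loss).backward()  # scaled backward
                self.grad_scaler.step(self.inner)
                self.grad_scaler.update()
            else:
                loss.backward()  # plain fp32 path, no scaling
                self.inner.step()

    model = torch.nn.Linear(4, 1)
    opt = ToyZeroStep(model.parameters(), use_fp16=False)  # fp32: no scaler at all
    opt.backward_and_step(model(torch.randn(2, 4)).sum())
    print(opt.loss_scale)  # -> 1.0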
@@ -533,7 +551,8 @@ class HybridZeroOptimizer(BaseOptimizer):
             norm_groups.append(norm_group)

         loss_scale = float(self.loss_scale.item())  # backup
-        self.grad_scaler.update(found_inf)
+        if self.grad_scaler:
+            self.grad_scaler.update(found_inf)
         # update loss scale if overflow occurs
         if found_inf:
             if gpc.is_rank_for_log():
@@ -552,21 +571,30 @@ class HybridZeroOptimizer(BaseOptimizer):
                 continue
             gradients = self._grad_store.get_averaged_gradients_by_group(group_id)

-            # create flat gradient for the flat fp32 params
-            fp16_avg_grads = gradients
-            flat_fp16_avg_grads = flatten(fp16_avg_grads)
-
-            dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
-            flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
-
-            param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
-            assert (
-                param_shape == flat_fp32_avg_grads.shape
-            ), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"
-
-            single_grad_partition_groups.append(flat_fp32_avg_grads)
-            device = self._fp32_flat_param_groups_of_current_rank[group_id].device
-            self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
+            if self.use_fp16:
+                # create flat gradient for the flat fp32 params
+                fp16_avg_grads = gradients
+                flat_fp16_avg_grads = flatten(fp16_avg_grads)
+
+                dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
+                flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
+
+                param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
+                assert (
+                    param_shape == flat_fp32_avg_grads.shape
+                ), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"
+
+                single_grad_partition_groups.append(flat_fp32_avg_grads)
+                device = self._fp32_flat_param_groups_of_current_rank[group_id].device
+                self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
+            else:
+                assert len(gradients) == len(self.optim.param_groups[group_id]["params"]), (
+                    len(gradients),
+                    len(self.optim.param_groups[group_id]["params"]),
+                )
+                for g, p in zip(gradients, self.optim.param_groups[group_id]["params"]):
+                    p.grad = g

             self._grad_store._averaged_gradients[group_id] = []
             self._grad_store._averaged_gradients[group_id] = []
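The else branch above is the whole fp32 gradient path: the averaged gradients are attached one-to-one to the partitioned parameters instead of being flattened into an fp32 master-grad buffer. A standalone illustration of that assignment with plain torch objects (illustrative names, not the InternLM stores):

    import torch

    params = [torch.randn(3, requires_grad=True), torch.randn(2, requires_grad=True)]
    gradients = [torch.ones(3), torch.ones(2)]  # stand-in for the averaged grads
    assert len(gradients) == len(params)
    for g, p in zip(gradients, params):
        p.grad = g  # the inner optimizer's step() reads p.grad directly
    torch.optim.SGD(params, lr=0.1).step()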
@@ -576,8 +604,9 @@ class HybridZeroOptimizer(BaseOptimizer):
             global_norm = sum(norm_groups) ** 0.5

             # the following operations are performed only on the rank to which parameters are assigned.
-            if len(single_grad_partition_groups) != 0:
-                self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)
+            if self.use_fp16:
+                if len(single_grad_partition_groups) != 0:
+                    self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)

             timer("cal_norm").stop()
         # update the parameters
@@ -588,15 +617,16 @@ class HybridZeroOptimizer(BaseOptimizer):
         if self.has_params:
             self.optim.step()
             # release the fp32 grad
-            release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
-            # update fp16 partition updated by the current rank
-            for group_id in range(len(self._fp16_param_groups)):
-                if self.param_group_has_params[group_id]:
-                    fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
-                        rank=self._zero_local_rank, group_id=group_id
-                    )
-                    fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
-                    fp16_param.data.copy_(fp32_param)
+            if self.use_fp16:
+                release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
+                # update fp16 partition updated by the current rank
+                for group_id in range(len(self._fp16_param_groups)):
+                    if self.param_group_has_params[group_id]:
+                        fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
+                            rank=self._zero_local_rank, group_id=group_id
+                        )
+                        fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
+                        fp16_param.data.copy_(fp32_param)

         # TODO: support broadcast overlap
         self.broadcast_params(overlap=False)
@@ -614,8 +644,6 @@ class HybridZeroOptimizer(BaseOptimizer):
                 # The following operations are performed only on the rank to which parameters are assigned.
                 if rank not in self.param_group_no_params_ranks[group_id]:
                     fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
-                    # grank = gpc.get_ranks_in_group(group_type)[rank]  # need to convert to the global rank
-                    # assert grank == rank, f"{grank} == {rank}"
                     g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
                     handle = dist.broadcast(
                         fp16_param, src=g_rank, group=gpc.get_group(ParallelMode.ZERO1), async_op=True
@@ -667,48 +695,52 @@ class HybridZeroOptimizer(BaseOptimizer):
     def state_dict(self):
         states = {}
-        grad_scaler = self.grad_scaler.state_dict()
-        states["grad_scaler"] = grad_scaler

         optim_states = self.optim.state_dict()
         states["base_optim_states"] = optim_states

-        flat_fp32_weights = {}
-        for group_id, param in self._fp32_flat_param_groups_of_current_rank.items():
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-                assert param.grad is None
-                flat_fp32_weights[group_id] = param
-        states["flat_fp32_weights"] = flat_fp32_weights
+        if self.use_fp16:
+            grad_scaler = self.grad_scaler.state_dict()
+            states["grad_scaler"] = grad_scaler
+
+            flat_fp32_weights = {}
+            for group_id, param in self._fp32_flat_param_groups_of_current_rank.items():
+                if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
+                    assert param.grad is None
+                    flat_fp32_weights[group_id] = param
+            states["flat_fp32_weights"] = flat_fp32_weights

         states["zero_devide_optim_plan"] = self.params_per_rank_id_dict

         return states

     def load_state_dict(self, states):
         # TODO: Need to take into account the change in the number of DP.
-        assert "grad_scaler" in states, "Not found grad_scaler state!"
-        grad_scaler = states["grad_scaler"]
-        self.grad_scaler.load_state_dict(grad_scaler)
-
         optim_states = states["base_optim_states"]
         self.optim.load_state_dict(optim_states)

-        # load fp32 model weight.
-        flat_fp32_weights = states["flat_fp32_weights"]
-        assert set(flat_fp32_weights.keys()) == set(self._fp32_flat_param_groups_of_current_rank)
-        for group_id, param in flat_fp32_weights.items():
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-                self_param = self._fp32_flat_param_groups_of_current_rank[group_id]
-                assert (
-                    self_param.shape == param.shape
-                ), f"The loaded parameter shape is inconsistent, {self_param.shape} != {param.shape}"
-                self_param.data.copy_(param.data)
-
-        # Load the fp16 model weights.
-        for group_id in range(len(self._fp16_param_groups)):
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-                fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
-                    rank=self._zero_local_rank, group_id=group_id
-                )
-                fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
-                fp16_param.data.copy_(fp32_param)
+        if self.use_fp16:
+            assert "grad_scaler" in states, "Not found grad_scaler state!"
+            grad_scaler = states["grad_scaler"]
+            self.grad_scaler.load_state_dict(grad_scaler)
+
+            # load fp32 model weight.
+            flat_fp32_weights = states["flat_fp32_weights"]
+            assert set(flat_fp32_weights.keys()) == set(self._fp32_flat_param_groups_of_current_rank)
+            for group_id, param in flat_fp32_weights.items():
+                if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
+                    self_param = self._fp32_flat_param_groups_of_current_rank[group_id]
+                    assert (
+                        self_param.shape == param.shape
+                    ), f"The loaded parameter shape is inconsistent, {self_param.shape} != {param.shape}"
+                    self_param.data.copy_(param.data)
+
+            # Load the fp16 model weights.
+            for group_id in range(len(self._fp16_param_groups)):
+                if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
+                    fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
+                        rank=self._zero_local_rank, group_id=group_id
+                    )
+                    fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
+                    fp16_param.data.copy_(fp32_param)

         if "zero_devide_optim_plan" in states:
             self.params_per_rank_id_dict = states["zero_devide_optim_plan"]
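With this change the checkpoint layout becomes mode-dependent: fp32 runs save only the inner optimizer state and the partition plan, while "grad_scaler" and "flat_fp32_weights" appear only for fp16 runs. A toy round-trip of the fp32-mode layout (the dict keys follow the diff; everything else here is illustrative):

    import torch

    net = torch.nn.Linear(4, 2)
    inner = torch.optim.AdamW(net.parameters(), lr=1e-3)

    states = {
        "base_optim_states": inner.state_dict(),
        "zero_devide_optim_plan": {},  # placeholder for params_per_rank_id_dict
    }
    assert "grad_scaler" not in states and "flat_fp32_weights" not in states
    inner.load_state_dict(states["base_optim_states"])  # restore on load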

File 2 of 2: initialize_optimizer

@@ -282,7 +282,10 @@ def initialize_optimizer(model: nn.Module):
     )

     optimizer = HybridZeroOptimizer(
-        naive_optimizer, grad_scal_cfg=gpc.config.grad_scaler, zero_cfg=gpc.config.hybrid_zero_optimizer
+        naive_optimizer,
+        grad_scal_cfg=gpc.config.grad_scaler,
+        zero_cfg=gpc.config.hybrid_zero_optimizer,
+        use_fp16=gpc.config.model.dtype is not torch.float32,
     )

     beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)