mirror of https://github.com/InternLM/InternLM

Merge pull request #2 from SolenoidWGT/fp32_zero

feat(optim): add support for fp32 zero  (pull/155/head)

commit 53fc50b0e5
@@ -87,6 +87,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         overlap_broadcast=False,
         grad_scal_cfg: Config = None,
         zero_cfg: Config = None,
+        use_fp16: bool = True,
     ):
         # DynamicGradScaler related args
         initial_scale = grad_scal_cfg.fp16.initial_scale
@@ -104,6 +105,7 @@ class HybridZeroOptimizer(BaseOptimizer):

         super().__init__(optim=optimizer)

+        self.use_fp16 = use_fp16
         self._dtype = self.optim.param_groups[0]["params"][0].dtype
         self._cpu_offload = cpu_offload
         self._zero_local_rank = gpc.get_local_rank(ParallelMode.ZERO1)
@@ -125,14 +127,18 @@ class HybridZeroOptimizer(BaseOptimizer):
         self._reduce_bucket_size = reduce_bucket_size

         # gradient scaler
-        self.grad_scaler = DynamicGradScaler(
-            initial_scale=initial_scale,
-            min_scale=min_scale,
-            growth_factor=growth_factor,
-            backoff_factor=backoff_factor,
-            growth_interval=growth_interval,
-            hysteresis=hysteresis,
-            max_scale=max_scale,
+        self.grad_scaler = (
+            DynamicGradScaler(
+                initial_scale=initial_scale,
+                min_scale=min_scale,
+                growth_factor=growth_factor,
+                backoff_factor=backoff_factor,
+                growth_interval=growth_interval,
+                hysteresis=hysteresis,
+                max_scale=max_scale,
+            )
+            if self.use_fp16
+            else None
         )
         self._found_overflow = torch.cuda.FloatTensor([0], device=get_current_device())
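With this change the scaler exists only for fp16 runs; a pure fp32 ZeRO run gets `self.grad_scaler = None`, and every later use has to be guarded. A minimal standalone sketch of that pattern follows (the toy `Scaler` class is illustrative, not InternLM's `DynamicGradScaler`):

import torch

class Scaler:
    """Toy dynamic loss scaler: only the construction pattern matters here."""
    def __init__(self, initial_scale=2.0**16):
        self.scale = torch.tensor(initial_scale)

    def update(self, found_inf: bool):
        # halve on overflow, otherwise keep the scale (toy rule)
        if found_inf:
            self.scale *= 0.5

def make_scaler(use_fp16: bool):
    # mirror of the diff: a scaler object for fp16, None for fp32
    return Scaler() if use_fp16 else None

scaler = make_scaler(use_fp16=False)
loss_scale = scaler.scale if scaler is not None else torch.tensor(1.0)
print(float(loss_scale))  # 1.0 for an fp32 run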
@@ -176,11 +182,14 @@ class HybridZeroOptimizer(BaseOptimizer):
             for param in params:
                 self._param_store.set_param_to_rank(param, rank)

+            # flatten the reordered tensors
             # move to cpu to make room to create the flat tensor
+            # Even for fp32 training, we will still flatten the tensor,
+            # which will not increase the use of GPU memory,
+            # and can improve the efficiency of broadcasting.
             for param in group_params:
                 param.data = param.data.cpu()

-            # flatten the reordered tensors
             for rank in range(self._zero_world_size):
                 # No flat fp16 buffer is allocated if the process has no parameters.
                 if rank not in self.param_group_no_params_ranks[group_id]:
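The new comments argue that flattening is still worth doing for fp32: one contiguous buffer can be broadcast far more efficiently than many small tensors, and the flat copy does not add GPU memory while the originals sit on the CPU. A hedged sketch of the flatten/unflatten round trip using PyTorch's internal helpers (the tensor shapes are made up for illustration):

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# a few toy "parameters" of different shapes (illustrative only)
params = [torch.randn(4, 4), torch.randn(8), torch.randn(2, 3)]

# one contiguous buffer: a single broadcast/reduce instead of three
flat = _flatten_dense_tensors(params)
assert flat.numel() == sum(p.numel() for p in params)

# give each original tensor a view into the flat buffer,
# so in-place updates of `flat` are visible through the params
for p, synced in zip(params, _unflatten_dense_tensors(flat, params)):
    p.data = synced

flat.add_(1.0)  # e.g. an update arriving via broadcast
print(torch.allclose(params[0], _unflatten_dense_tensors(flat, params)[0]))  # True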
@@ -194,19 +203,25 @@ class HybridZeroOptimizer(BaseOptimizer):
             # create a copy of fp32 weights of the parameters for which this rank is responsible
             # No flat fp32 buffer is allocated if the process has no parameters.
             if self.param_group_has_params[group_id]:
-                fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
-                    self._zero_local_rank, group_id
-                )
-                fp32_flat_current_rank = fp16_flat_current_rank.float()
-                device = "cpu" if self._cpu_offload else get_current_device()
-                fp32_flat_current_rank = fp32_flat_current_rank.to(device)
-                fp32_flat_current_rank.requires_grad = True
-                self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank
+                if self.use_fp16:
+                    fp16_flat_current_rank = self._param_store.get_flat_fp16_param_by_rank_group(
+                        self._zero_local_rank, group_id
+                    )
+                    fp32_flat_current_rank = fp16_flat_current_rank.float()
+                    device = "cpu" if self._cpu_offload else get_current_device()
+                    fp32_flat_current_rank = fp32_flat_current_rank.to(device)
+                    fp32_flat_current_rank.requires_grad = True
+                    self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank

-                # need to replace the params in the `params` field in the optimizer
-                # so that when the optimizer calls step(), it only updates the tensors
-                # managed by this data parallel rank
-                param_group["params"] = [fp32_flat_current_rank]
+                    # need to replace the params in the `params` field in the optimizer
+                    # so that when the optimizer calls step(), it only updates the tensors
+                    # managed by this data parallel rank
+                    param_group["params"] = [fp32_flat_current_rank]
+                else:
+                    # use fp32
+                    param_group["params"] = self._param_store.get_fp16_params_by_rank_group(
+                        self._zero_local_rank, group_id
+                    )

             # set reduction state
             for param in self._fp16_param_groups[group_id]:
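So with `use_fp16=True` the wrapped optimizer only ever sees the flat fp32 master copy of this rank's shard, while with `use_fp16=False` it keeps the original fp32 parameters. A toy sketch of what ends up in `param_group["params"]` in each case (plain tensors, no param store):

import torch

def build_param_group(flat_fp16_shard, raw_params, use_fp16, lr=0.1):
    """Sketch: the "params" entry for one group on one rank."""
    if use_fp16:
        master = flat_fp16_shard.float()      # fp32 master copy of the flat shard
        master.requires_grad = True
        return {"params": [master], "lr": lr}, master
    # fp32 training: the optimizer updates the original parameters directly
    return {"params": list(raw_params), "lr": lr}, None

flat = torch.randn(16, dtype=torch.float16)
raw = [torch.randn(8, requires_grad=True), torch.randn(8, requires_grad=True)]
group, master = build_param_group(flat, raw, use_fp16=True)
print(master.dtype, len(group["params"]))  # torch.float32 1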
@@ -243,7 +258,10 @@ class HybridZeroOptimizer(BaseOptimizer):

     @property
     def loss_scale(self):
-        return self.grad_scaler.scale
+        if self.grad_scaler is None:
+            return 1
+        else:
+            return self.grad_scaler.scale

     @property
     def num_param_groups(self):
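Returning 1 when there is no scaler keeps `loss_scale` usable in fp32 runs; since the surrounding code calls `self.loss_scale.item()`, a 0-d tensor is the safer return type in practice. A hedged sketch of that variant (a standalone class, not the optimizer itself):

import torch

class LossScaleMixin:
    """Sketch only: a scaler-optional loss_scale that always yields a tensor."""
    def __init__(self, grad_scaler=None):
        self.grad_scaler = grad_scaler

    @property
    def loss_scale(self):
        if self.grad_scaler is None:
            # 0-d tensor so downstream float(self.loss_scale.item()) still works
            return torch.tensor(1.0)
        return self.grad_scaler.scale

print(float(LossScaleMixin().loss_scale.item()))  # 1.0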
@@ -533,7 +551,8 @@ class HybridZeroOptimizer(BaseOptimizer):
             norm_groups.append(norm_group)

         loss_scale = float(self.loss_scale.item())  # backup
-        self.grad_scaler.update(found_inf)
+        if self.grad_scaler:
+            self.grad_scaler.update(found_inf)
         # update loss scale if overflow occurs
         if found_inf:
             if gpc.is_rank_for_log():
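The scale update is now guarded because `self.grad_scaler` may be `None`. A small sketch of the guard plus the overflow-skip decision that the following lines act on (toy helper with an assumed scaler interface):

def maybe_update_scaler(grad_scaler, found_inf: bool) -> bool:
    """Sketch: update the dynamic scale only when a scaler exists;
    return True when the optimizer step should be skipped (overflow)."""
    if grad_scaler is not None:
        grad_scaler.update(found_inf)
    return bool(found_inf)

print(maybe_update_scaler(None, found_inf=False))  # False: the fp32 path never skips here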
@@ -552,21 +571,30 @@ class HybridZeroOptimizer(BaseOptimizer):
                 continue
             gradients = self._grad_store.get_averaged_gradients_by_group(group_id)

-            # create flat gradient for the flat fp32 params
-            fp16_avg_grads = gradients
-            flat_fp16_avg_grads = flatten(fp16_avg_grads)
-
-            dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
-            flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
-
-            param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
-            assert (
-                param_shape == flat_fp32_avg_grads.shape
-            ), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"
-
-            single_grad_partition_groups.append(flat_fp32_avg_grads)
-            device = self._fp32_flat_param_groups_of_current_rank[group_id].device
-            self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
+            if self.use_fp16:
+                # create flat gradient for the flat fp32 params
+                fp16_avg_grads = gradients
+                flat_fp16_avg_grads = flatten(fp16_avg_grads)
+
+                dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
+                flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
+
+                param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
+                assert (
+                    param_shape == flat_fp32_avg_grads.shape
+                ), f"fp32 param and grad have different shape {param_shape} vs {flat_fp32_avg_grads.shape}"
+
+                single_grad_partition_groups.append(flat_fp32_avg_grads)
+                device = self._fp32_flat_param_groups_of_current_rank[group_id].device
+                self._fp32_flat_param_groups_of_current_rank[group_id].grad = flat_fp32_avg_grads.to(device)
+            else:
+                assert len(gradients) == len(self.optim.param_groups[group_id]["params"]), (
+                    len(gradients),
+                    len(self.optim.param_groups[group_id]["params"]),
+                )
+                for g, p in zip(gradients, self.optim.param_groups[group_id]["params"]):
+                    p.grad = g

             self._grad_store._averaged_gradients[group_id] = []
             self._grad_store._averaged_gradients[group_id] = []
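For fp16, the averaged gradients are flattened, cast to the dtype and device of the flat fp32 master shard, and attached to it as a single `.grad`; for fp32, each averaged gradient is attached directly to its parameter. A toy sketch of the two branches (plain lists of tensors instead of the grad/param stores):

import torch
from torch._utils import _flatten_dense_tensors as flatten

def attach_grads(use_fp16, gradients, flat_fp32_master=None, params=None):
    """Sketch of the two branches in step(): returns the tensors that will be clipped."""
    if use_fp16:
        flat = flatten(gradients).to(flat_fp32_master.dtype).to(flat_fp32_master.device)
        assert flat.shape == flat_fp32_master.shape
        flat_fp32_master.grad = flat
        return [flat]
    assert len(gradients) == len(params)
    for g, p in zip(gradients, params):
        p.grad = g
    return []

# fp32 branch: gradients go straight onto the parameters
ps = [torch.zeros(3, requires_grad=True), torch.zeros(2, requires_grad=True)]
gs = [torch.ones(3), torch.ones(2)]
attach_grads(False, gs, params=ps)
print(ps[0].grad)  # tensor([1., 1., 1.])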
@@ -576,8 +604,9 @@ class HybridZeroOptimizer(BaseOptimizer):
             global_norm = sum(norm_groups) ** 0.5

             # the following operations are performed only on the rank to which parameters are assigned.
-            if len(single_grad_partition_groups) != 0:
-                self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)
+            if self.use_fp16:
+                if len(single_grad_partition_groups) != 0:
+                    self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)

         timer("cal_norm").stop()
         # update the parameters
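Unscale-and-clip only applies to the fp16 path, since only those gradients still carry the loss scale. A hedged sketch of what such a helper typically does (the real `_unscale_and_clip_grads` may differ in detail; `clip_grad` is an assumed parameter):

import torch

def unscale_and_clip_(grads, total_norm, loss_scale, clip_grad=1.0):
    # combine the loss scale with a gradient-clipping coefficient,
    # then apply both with a single in-place division per gradient
    combined_scale = loss_scale
    if clip_grad > 0:
        clip_coef = float(total_norm) / loss_scale / (clip_grad + 1e-6)
        if clip_coef > 1.0:
            combined_scale = clip_coef * loss_scale
    for g in grads:
        g.div_(combined_scale)

g = [torch.full((4,), 2.0 * 1024)]  # gradients still scaled by loss_scale = 1024
unscale_and_clip_(g, total_norm=4.0, loss_scale=1024.0, clip_grad=8.0)
print(g[0])  # back to the unscaled values; no clipping triggered in this example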
@@ -588,15 +617,16 @@ class HybridZeroOptimizer(BaseOptimizer):
         if self.has_params:
             self.optim.step()
             # release the fp32 grad
-            release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
-            # update fp16 partition updated by the current rank
-            for group_id in range(len(self._fp16_param_groups)):
-                if self.param_group_has_params[group_id]:
-                    fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
-                        rank=self._zero_local_rank, group_id=group_id
-                    )
-                    fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
-                    fp16_param.data.copy_(fp32_param)
+            if self.use_fp16:
+                release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
+                # update fp16 partition updated by the current rank
+                for group_id in range(len(self._fp16_param_groups)):
+                    if self.param_group_has_params[group_id]:
+                        fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
+                            rank=self._zero_local_rank, group_id=group_id
+                        )
+                        fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
+                        fp16_param.data.copy_(fp32_param)

         # TODO: support broadcast overlap
         self.broadcast_params(overlap=False)
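After the inner `step()` the fp32 master shard holds the new weights, so the fp16 path releases the fp32 grads and copies the result back into the flat fp16 partition; the fp32 path needs neither. A runnable sketch of that copy-back on plain tensors:

import torch

# fp16 path: the inner optimizer steps on the fp32 master,
# then the flat fp16 shard is refreshed from it
fp16_flat = torch.zeros(8, dtype=torch.float16)
fp32_master = fp16_flat.float().requires_grad_()
opt = torch.optim.SGD([fp32_master], lr=1.0)

fp32_master.grad = torch.ones_like(fp32_master)
opt.step()

fp32_master.grad = None            # release the fp32 grad, as in the diff
fp16_flat.data.copy_(fp32_master)  # update the fp16 partition owned by this rank
print(fp16_flat)                   # all -1.0, now in fp16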
@@ -614,8 +644,6 @@ class HybridZeroOptimizer(BaseOptimizer):
                 # The following operations are performed only on the rank to which parameters are assigned.
                 if rank not in self.param_group_no_params_ranks[group_id]:
                     fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(rank=rank, group_id=group_id)
-                    # grank = gpc.get_ranks_in_group(group_type)[rank]  # need to convert to the global rank
-                    # assert grank == rank, f"{grank} == {rank}"
                     g_rank = gpc.get_ranks_in_group(self._broadcast_parallel_mode)[rank]
                     handle = dist.broadcast(
                         fp16_param, src=g_rank, group=gpc.get_group(ParallelMode.ZERO1), async_op=True
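Each ZeRO-1 rank then broadcasts the flat shard it owns to the other ranks in the group, converting the group-local rank to a global rank for `dist.broadcast`. A hedged sketch of that loop, assuming an already-initialized process group (the helper and its arguments are illustrative, not the optimizer's API):

import torch.distributed as dist

def broadcast_flat_shards(flat_shards_by_rank, group, group_ranks):
    """flat_shards_by_rank: {group_local_rank: flat tensor owned by that rank}."""
    handles = []
    for local_rank, flat in flat_shards_by_rank.items():
        g_rank = group_ranks[local_rank]  # group-local rank -> global rank
        handles.append(dist.broadcast(flat, src=g_rank, group=group, async_op=True))
    for h in handles:                     # let the broadcasts overlap, then sync
        h.wait()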
@@ -667,48 +695,52 @@ class HybridZeroOptimizer(BaseOptimizer):

     def state_dict(self):
         states = {}
-        grad_scaler = self.grad_scaler.state_dict()
-        states["grad_scaler"] = grad_scaler
         optim_states = self.optim.state_dict()
         states["base_optim_states"] = optim_states

-        flat_fp32_weights = {}
-        for group_id, param in self._fp32_flat_param_groups_of_current_rank.items():
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-                assert param.grad is None
-                flat_fp32_weights[group_id] = param
-        states["flat_fp32_weights"] = flat_fp32_weights
+        if self.use_fp16:
+            grad_scaler = self.grad_scaler.state_dict()
+            states["grad_scaler"] = grad_scaler
+
+            flat_fp32_weights = {}
+            for group_id, param in self._fp32_flat_param_groups_of_current_rank.items():
+                if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
+                    assert param.grad is None
+                    flat_fp32_weights[group_id] = param
+            states["flat_fp32_weights"] = flat_fp32_weights
+
         states["zero_devide_optim_plan"] = self.params_per_rank_id_dict

         return states

     def load_state_dict(self, states):
         # TODO: Need to take into account the change in the number of DP.
-        assert "grad_scaler" in states, "Not found grad_scaler state!"
-        grad_scaler = states["grad_scaler"]
-        self.grad_scaler.load_state_dict(grad_scaler)
         optim_states = states["base_optim_states"]
         self.optim.load_state_dict(optim_states)

-        # load fp32 model weight.
-        flat_fp32_weights = states["flat_fp32_weights"]
-        assert set(flat_fp32_weights.keys()) == set(self._fp32_flat_param_groups_of_current_rank)
-        for group_id, param in flat_fp32_weights.items():
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-                self_param = self._fp32_flat_param_groups_of_current_rank[group_id]
-                assert (
-                    self_param.shape == param.shape
-                ), f"The loaded parameter shape is inconsistent, {self_param.shape} != {param.shape}"
-                self_param.data.copy_(param.data)
+        if self.use_fp16:
+            assert "grad_scaler" in states, "Not found grad_scaler state!"
+            grad_scaler = states["grad_scaler"]
+            self.grad_scaler.load_state_dict(grad_scaler)

-        # Load the fp16 model weights.
-        for group_id in range(len(self._fp16_param_groups)):
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-                fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
-                    rank=self._zero_local_rank, group_id=group_id
-                )
-                fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
-                fp16_param.data.copy_(fp32_param)
+            # load fp32 model weight.
+            flat_fp32_weights = states["flat_fp32_weights"]
+            assert set(flat_fp32_weights.keys()) == set(self._fp32_flat_param_groups_of_current_rank)
+            for group_id, param in flat_fp32_weights.items():
+                if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
+                    self_param = self._fp32_flat_param_groups_of_current_rank[group_id]
+                    assert (
+                        self_param.shape == param.shape
+                    ), f"The loaded parameter shape is inconsistent, {self_param.shape} != {param.shape}"
+                    self_param.data.copy_(param.data)
+
+            # Load the fp16 model weights.
+            for group_id in range(len(self._fp16_param_groups)):
+                if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
+                    fp16_param = self._param_store.get_flat_fp16_param_by_rank_group(
+                        rank=self._zero_local_rank, group_id=group_id
+                    )
+                    fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
+                    fp16_param.data.copy_(fp32_param)

         if "zero_devide_optim_plan" in states:
             self.params_per_rank_id_dict = states["zero_devide_optim_plan"]
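Checkpoints for fp32 runs therefore contain neither `grad_scaler` nor `flat_fp32_weights`, and loading must be conditional in the same way. A toy round-trip sketch with plain dicts (the field names follow the diff; everything else is illustrative):

import torch

def save_states(use_fp16, base_optim, scaler_state=None, flat_fp32_weights=None):
    states = {"base_optim_states": base_optim.state_dict()}
    if use_fp16:
        states["grad_scaler"] = scaler_state
        states["flat_fp32_weights"] = flat_fp32_weights
    return states

def load_states(use_fp16, base_optim, states):
    base_optim.load_state_dict(states["base_optim_states"])
    if use_fp16:
        assert "grad_scaler" in states, "Not found grad_scaler state!"
        # ... restore the scaler, fp32 master weights and fp16 shards here ...

p = torch.zeros(4, requires_grad=True)
opt = torch.optim.SGD([p], lr=0.1)
ckpt = save_states(use_fp16=False, base_optim=opt)
load_states(use_fp16=False, base_optim=opt, states=ckpt)
print(sorted(ckpt.keys()))  # ['base_optim_states'] only, for an fp32 run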
train.py (5 changed lines)
@@ -282,7 +282,10 @@ def initialize_optimizer(model: nn.Module):
     )

     optimizer = HybridZeroOptimizer(
-        naive_optimizer, grad_scal_cfg=gpc.config.grad_scaler, zero_cfg=gpc.config.hybrid_zero_optimizer
+        naive_optimizer,
+        grad_scal_cfg=gpc.config.grad_scaler,
+        zero_cfg=gpc.config.hybrid_zero_optimizer,
+        use_fp16=gpc.config.model.dtype is not torch.float32,
     )

     beta2_scheduler = Beta2Scheduler(optimizer=naive_optimizer, **gpc.config.beta2_scheduler)
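`use_fp16` is derived from the configured model dtype, so an fp32 model automatically disables the scaler and the fp32 master copies. A standalone illustration of that selection logic (the config object is a stand-in for `gpc.config.model`, not the real one):

import torch
from types import SimpleNamespace

# stand-in for gpc.config.model; only the dtype field matters here
model_cfg = SimpleNamespace(dtype=torch.float32)

use_fp16 = model_cfg.dtype is not torch.float32
print(use_fp16)  # False: a pure fp32 run skips the grad scaler and fp32 master copies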