mirror of https://github.com/InternLM/InternLM
fix(fsdp): add mixed-precision training
parent 85c6ed6473
commit 420e883d76
@@ -112,6 +112,28 @@ class FSDPadaptOptimizer(BaseOptimizer):
         self._clip_grad_norm = zero_cfg.clip_grad_norm
         self.use_fsdp = gpc.config.parallel.use_fsdp

+        # mark whether a module is part of TP or not
+        is_tensor_parallel_dict = dict()
+
+        # fp16 and fp32 params
+        # fp16 share mem space with model.FlatParam, fp32 share mem space with optim.param_group
+        self._fp16_param_groups = dict()
+        self._fp32_param_tensor_groups = dict()
+
+        # init fp16 and fp32 params
+        for group_idx, param_group in enumerate(self.optim.param_groups):
+            group_params = param_group["params"]
+
+            # fp16 FlatParam storage
+            self._fp16_param_groups[group_idx] = group_params
+
+            # create copy of fp32 weight
+            fp32_tensor_param = [param.data.float().requires_grad_(True) for param in group_params]
+            self._fp32_param_tensor_groups[group_idx] = fp32_tensor_param
+
+            # replace
+            param_group["params"] = fp32_tensor_param
+
     @property
     def loss_scale(self):
         return self.grad_scaler.scale
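For readers outside the InternLM codebase, here is a minimal sketch of the master-weight pattern this hunk introduces: keep the fp16 parameters the model computes with, hand the wrapped optimizer fp32 copies to step on, and remember the mapping between the two. The helper name `build_master_params` and the use of a plain `torch.optim` optimizer are assumptions for illustration, not InternLM APIs.

# Sketch only (editor illustration, not part of the commit): fp32 master weights
# mirroring fp16 params, using a plain torch.optim optimizer instead of InternLM's
# BaseOptimizer and FSDP FlatParams.
import torch


def build_master_params(optim: torch.optim.Optimizer):
    """Swap each param group's fp16 params for fp32 copies; return both mappings."""
    fp16_groups, fp32_groups = {}, {}
    for group_idx, param_group in enumerate(optim.param_groups):
        fp16_params = param_group["params"]
        # fp32 copies own their storage and will receive the (cast) gradients
        fp32_params = [p.detach().clone().float().requires_grad_(True) for p in fp16_params]
        fp16_groups[group_idx] = fp16_params
        fp32_groups[group_idx] = fp32_params
        # the optimizer now updates the fp32 masters, never the fp16 params directly
        param_group["params"] = fp32_params
    return fp16_groups, fp32_groups

This mirrors `self._fp16_param_groups` / `self._fp32_param_tensor_groups` above, where the fp16 side aliases FSDP's flattened parameter storage.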
@@ -120,12 +142,28 @@ class FSDPadaptOptimizer(BaseOptimizer):
         loss = self.loss_scale * loss
         loss.backward(retain_graph=retain_graph)

+    def _compute_norm_with_fsdp_flatten(self, group_id):
+        params = self._fp16_param_groups[group_id]
+        gradients = [p.grad for p in params]
+        norm_group = compute_norm(
+            gradients=gradients,
+            parameters=params,
+            last_stage=True
+        )
+
+        return norm_group
+
+    def zero_grad(self):
+        for _, param_group in self._fp16_param_groups.items():
+            for param in param_group:
+                param.grad = None
+
     def step(self):
         # in case that fsdp-zero3 size is not equal to dp size
         # FSDP module will only reduce gradient within FSDP process group
         # so manually reduce grad is essential between two parallel FSDP process group
         for group_idx in range(len(self.param_groups)):
-            params = self.param_groups[group_idx]["params"]
+            params = self._fp16_param_groups[group_idx]
             for param in params:
                 if param.requires_grad:
                     reduce_tensor(tensor=param.grad, parallel_mode=ParallelMode.ZERO3_DP)
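The comments in `step` above note that FSDP only reduces gradients inside its own process group, so when the FSDP group is smaller than the full data-parallel group the gradients still have to be averaged across replica groups. A rough sketch of that extra reduction with plain `torch.distributed`; the `inter_fsdp_group` handle is a hypothetical stand-in for InternLM's `ParallelMode.ZERO3_DP` group.

# Sketch only: average gradients across FSDP replica groups after backward.
# `inter_fsdp_group` is an assumed ProcessGroup linking ranks that hold the same
# shard in different FSDP instances (what ParallelMode.ZERO3_DP denotes above).
import torch.distributed as dist


def reduce_grads_across_replicas(params, inter_fsdp_group):
    replicas = dist.get_world_size(group=inter_fsdp_group)
    for p in params:
        if p.requires_grad and p.grad is not None:
            # sum this shard's gradient over the replica group, then average
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM, group=inter_fsdp_group)
            p.grad.div_(replicas)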
@@ -134,13 +172,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
         found_inf = False
         norm_groups = []
         for group_idx in range(len(self.param_groups)):
-            params = self.param_groups[group_idx]["params"]
-            gradients = [p.grad for p in params]
-            norm_group = compute_norm(
-                gradients=gradients,
-                parameters=params,
-                last_stage=True
-            )
+            norm_group = self._compute_norm_with_fsdp_flatten(group_idx)
             if norm_group == -1:
                 found_inf = True
                 break
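`compute_norm` returning `-1` is the overflow signal: a non-finite gradient norm makes `step` bail out so the loss scaler can back off. A hedged sketch of that convention with plain PyTorch; the helper name and signature here are assumptions, and the squared-norm return matches the later `sum(norm_groups) ** 0.5`.

# Sketch only: a norm helper that mimics the "-1 means inf/nan" convention used above.
import torch


def grad_norm_sq_or_minus_one(params):
    grads = [p.grad for p in params if p.grad is not None]
    if not grads:
        return 0.0
    total = torch.norm(torch.stack([torch.norm(g.float(), 2) for g in grads]), 2)
    if not torch.isfinite(total):
        return -1.0  # sentinel: overflow detected, caller skips optim.step()
    return total.item() ** 2  # squared per-group norm, summed then sqrt'ed by the caller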
@@ -157,15 +189,32 @@ class FSDPadaptOptimizer(BaseOptimizer):
         if self._clip_grad_norm > 0:
             global_norm = sum(norm_groups) ** 0.5

+        # create gradient for fp32 params
+        for group_idx in range(len(self.param_groups)):
+            dtype = self._fp32_param_tensor_groups[group_idx][0].dtype
+            fp16_params = self._fp16_param_groups[group_idx]
+            grad_fp32 = [p.grad.to(dtype) for p in fp16_params]
+
+            device = self._fp32_param_tensor_groups[group_idx][0].device
+            for p, g in zip(self._fp32_param_tensor_groups[group_idx], grad_fp32):
+                p.grad = g.to(device)
+
         # unscale
         for group_idx in range(len(self.param_groups)):
-            params = self.param_groups[group_idx]["params"]
+            params = self._fp32_param_tensor_groups[group_idx]
             for p in params:
                 self._unscale_and_clip_grads(p.grad, global_norm, loss_scale)

         self.optim.step()
         self.zero_grad()

+        # update fp16 param
+        for group_idx in range(len(self._fp16_param_groups)):
+            fp16_params = self._fp16_param_groups[group_idx]
+            fp32_tensor_params = self._fp32_param_tensor_groups[group_idx]
+            for p, q in zip(fp16_params, fp32_tensor_params):
+                p.data.copy_(q)
+
         return True, [global_norm / loss_scale]

     def clip_grad_norm(self, model, max_norm):
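The hunk above follows the usual loss-scaling recipe: cast the fp16 gradients onto the fp32 masters, divide by the loss scale (and by a clipping coefficient when the unscaled global norm exceeds the limit), step the optimizer on fp32, then copy the updated masters back into the fp16 storage the model runs on. A minimal sketch under those assumptions; `unscale_clip_step` is illustrative and only approximates what `_unscale_and_clip_grads` does in the class.

# Sketch only: unscale + clip, fp32 optimizer step, copy masters back to fp16.
def unscale_clip_step(optim, fp16_groups, fp32_groups, global_norm, loss_scale, max_norm):
    # move fp16 grads onto the fp32 masters
    for group_idx, fp16_params in fp16_groups.items():
        for p16, p32 in zip(fp16_params, fp32_groups[group_idx]):
            p32.grad = p16.grad.to(p32.dtype).to(p32.device)

    # grads are scaled by loss_scale, so the true norm is global_norm / loss_scale
    divisor = loss_scale
    if max_norm > 0 and global_norm / loss_scale > max_norm:
        divisor *= (global_norm / loss_scale) / max_norm  # fold clipping into the divisor
    for fp32_params in fp32_groups.values():
        for p in fp32_params:
            p.grad.div_(divisor)

    optim.step()  # updates the fp32 masters only

    # publish the update to the fp16 params the forward pass uses
    for group_idx, fp16_params in fp16_groups.items():
        for p16, p32 in zip(fp16_params, fp32_groups[group_idx]):
            p16.data.copy_(p32.data)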
@@ -196,6 +245,11 @@ class FSDPadaptOptimizer(BaseOptimizer):
         optim_states = self.optim.state_dict()
         states["base_optim_states"] = optim_states

+        flat_fp32_weights = {}
+        for group_idx, param in self._fp32_param_tensor_groups.items():
+            flat_fp32_weights[group_idx] = param
+        states["flat_fp32_weights"] = flat_fp32_weights
+
         return states

     def load_state_dict(self, states):
@@ -205,6 +259,24 @@ class FSDPadaptOptimizer(BaseOptimizer):
         optim_states = states["base_optim_states"]
         self.optim.load_state_dict(optim_states)

+        # load fp32 optimizer weight
+        flat_fp32_weights = states["flat_fp32_weights"]
+        assert set(flat_fp32_weights.keys()) == set(self._fp32_param_tensor_groups)
+        for group_idx, param in flat_fp32_weights.items():
+            self_param = self._fp32_param_tensor_groups[group_idx]
+            assert (
+                len(self_param) == len(param)
+            ), f"The number of flat tensor is inconsistent, {len(self_param)} != {len(param)}"
+            for p, q in zip(self_param, param):
+                p.data.copy_(q.data)
+
+        # load fp16 model weight
+        for group_idx, param in flat_fp32_weights.items():
+            fp16_param = self._fp16_param_groups[group_idx]
+            fp32_param = self._fp32_param_tensor_groups[group_idx]
+            for p, q in zip(fp16_param, fp32_param):
+                p.data.copy_(q.data)


 class HybridZeroOptimizer(BaseOptimizer):
     """
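Finally, the `state_dict` / `load_state_dict` hunks persist the wrapped optimizer's state together with the per-group fp32 master weights, and on load rebuild both the fp32 masters and the fp16 model weights from them. A hedged sketch of that round trip outside InternLM; the function names and the flat `fp16_groups` / `fp32_groups` layout are assumptions carried over from the sketches above.

# Sketch only: checkpoint the base optimizer state plus fp32 master weights,
# then restore fp32 masters and re-derive the fp16 params from them on load.
def mixed_precision_state_dict(optim, fp32_groups):
    return {
        "base_optim_states": optim.state_dict(),
        "flat_fp32_weights": {idx: params for idx, params in fp32_groups.items()},
    }


def load_mixed_precision_state_dict(states, optim, fp16_groups, fp32_groups):
    optim.load_state_dict(states["base_optim_states"])
    saved = states["flat_fp32_weights"]
    assert set(saved) == set(fp32_groups), "checkpoint and optimizer group ids must match"
    for group_idx, saved_params in saved.items():
        for p, q in zip(fp32_groups[group_idx], saved_params):
            p.data.copy_(q.data)  # restore fp32 masters
        for p16, p32 in zip(fp16_groups[group_idx], fp32_groups[group_idx]):
            p16.data.copy_(p32.data)  # fp16 weights follow the restored masters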