mirror of https://github.com/InternLM/InternLM
fix(internlm/train/training_internlm.py): remove set IS_TENSOR_PARALLEL attr
parent 1b71b19e23
commit 2e94870967
@@ -78,8 +78,8 @@ class FSDPadaptOptimizer(BaseOptimizer):
         loss.backward(retain_graph=retain_graph)
 
     def _compute_norm_with_fsdp_flatten(self, group_id):
-        params = [p for p in self._fp16_param_groups[group_id] if p.storage().size() != 0]
-        gradients = [p.grad for p in params if p.storage().size() != 0]
+        params = [p for p in self._fp16_param_groups[group_id] if p.untyped_storage().size() != 0]
+        gradients = [p.grad for p in params if p.untyped_storage().size() != 0]
         norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)
 
         return norm_group
@@ -128,11 +128,11 @@ class FSDPadaptOptimizer(BaseOptimizer):
         # create gradient for fp32 params
         for group_idx in range(len(self.param_groups)):
             dtype = self._fp32_param_tensor_groups[group_idx][0].dtype
-            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.storage().size() != 0]
+            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
             grad_fp32 = [p.grad.to(dtype) for p in fp16_params]
 
             device = self._fp32_param_tensor_groups[group_idx][0].device
-            nonzero_fp32 = [p for p in self._fp32_param_tensor_groups[group_idx] if p.storage().size() != 0]
+            nonzero_fp32 = [p for p in self._fp32_param_tensor_groups[group_idx] if p.untyped_storage().size() != 0]
             for p, g in zip(nonzero_fp32, grad_fp32):
                 p.grad = g.to(device)
 
@@ -143,8 +143,10 @@ class FSDPadaptOptimizer(BaseOptimizer):
         self.zero_grad()
 
         for group_idx in range(len(self._fp16_param_groups)):
-            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.storage().size() != 0]
-            fp32_tensor_params = [p for p in self._fp32_param_tensor_groups[group_idx] if p.storage().size() != 0]
+            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
+            fp32_tensor_params = [
+                p for p in self._fp32_param_tensor_groups[group_idx] if p.untyped_storage().size() != 0
+            ]
             # release fp32 grad
             release_param_grad(fp32_tensor_params)
             # update fp16 param
@@ -177,7 +179,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
 
         for group_id, param in self._fp32_param_tensor_groups.items():
             for p in param:
-                if p.storage().size() != 0:
+                if p.untyped_storage().size() != 0:
                     p.grad.data.mul_(1.0 / combined_scale_groups[group_id])
 
     def state_dict(self):
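For context on the storage-size checks above: under FSDP's flat-parameter sharding, a parameter that is not part of the local shard can be left with an empty storage, which is what the size() != 0 filters detect. In recent PyTorch, Tensor.storage() goes through the deprecated TypedStorage wrapper, while Tensor.untyped_storage() is the supported accessor; both report zero exactly when the storage holds no data (TypedStorage counts elements, UntypedStorage counts bytes), so the filters behave the same after the switch. A small standalone illustration, not InternLM code:

import torch

full = torch.randn(4, 4)
empty = torch.empty(0)

# UntypedStorage.size() reports bytes; it is zero exactly when the tensor's
# storage holds no data, which is what the filters in the diff test for.
print(full.untyped_storage().size())   # 64 (16 float32 elements * 4 bytes)
print(empty.untyped_storage().size())  # 0

# Tensor.storage() still works but goes through the deprecated TypedStorage
# wrapper, whose size() counts elements instead of bytes; hence the switch.
print(full.storage().size())           # 16
print(empty.storage().size())          # 0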
@@ -108,15 +108,6 @@ def initialize_model():
 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     RMSNorm = try_import_RMSNorm()
     if gpc.config.parallel.use_fsdp:
-        # pre-save info for tensor parallel
-        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
-            tp_dict = dict()
-            for name, param in model.named_parameters():
-                if hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL):
-                    tp_dict.update({name.replace("model.", ""): True})
-                else:
-                    tp_dict.update({name.replace("model.", ""): False})
-
         # set wrap_policy for fsdp wrap
         transformer_wrap_policy = functools.partial(
             transformer_auto_wrap_policy,
@@ -133,15 +124,9 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
             forward_prefetch=True,
             backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
             limit_all_gathers=True,
-            use_orig_params=False,
+            use_orig_params=True,
         )
 
-        # re-set attribute for fsdp module with tensor parallel
-        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
-            for (name, param), pre in zip(model.named_parameters(), tp_dict):
-                if pre in name and tp_dict[pre]:
-                    setattr(param, IS_TENSOR_PARALLEL, True)
-
     return model
 
 
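The training-script half of the diff removes the manual IS_TENSOR_PARALLEL bookkeeping around FSDP wrapping and flips use_orig_params to True. With use_orig_params=True, FSDP exposes the module's original parameters through named_parameters() instead of a single flattened parameter, so metadata attached to a parameter before wrapping can still be looked up afterwards, which appears to be why the pre-save and re-set passes are no longer needed. A minimal sketch of that behavior, assuming a single CUDA rank launched with torchrun --nproc_per_node=1; the toy model and the is_tp tag are illustrative, not InternLM code:

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank())

model = nn.Sequential(nn.Linear(32, 32), nn.Linear(32, 32)).cuda()
model[0].weight.is_tp = True  # illustrative per-parameter tag, set before wrapping

wrapped = FSDP(model, use_orig_params=True)

# With use_orig_params=True, named_parameters() yields one entry per original
# parameter rather than a single flattened parameter, so the tag applied
# above can still be looked up on the wrapped model.
for name, param in wrapped.named_parameters():
    print(name, getattr(param, "is_tp", False))

dist.destroy_process_group()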