mirror of https://github.com/InternLM/InternLM
fix(internlm/train/training_internlm.py): remove set IS_TENSOR_PARALLEL attr
parent 1b71b19e23
commit 2e94870967
@@ -78,8 +78,8 @@ class FSDPadaptOptimizer(BaseOptimizer):
         loss.backward(retain_graph=retain_graph)

     def _compute_norm_with_fsdp_flatten(self, group_id):
-        params = [p for p in self._fp16_param_groups[group_id] if p.storage().size() != 0]
-        gradients = [p.grad for p in params if p.storage().size() != 0]
+        params = [p for p in self._fp16_param_groups[group_id] if p.untyped_storage().size() != 0]
+        gradients = [p.grad for p in params if p.untyped_storage().size() != 0]
         norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)

         return norm_group
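Note on the change above: torch.Tensor.storage() is deprecated in PyTorch 2.x in favour of torch.Tensor.untyped_storage(), and for the emptiness test used here the two are interchangeable (zero elements implies zero bytes). Below is a minimal, self-contained sketch of the predicate outside of FSDP; the helper name has_local_storage and the toy parameters are illustrative only, the apparent intent being to skip parameters whose local storage has been freed.

import torch

def has_local_storage(p: torch.Tensor) -> bool:
    # Same check as in the diff: skip parameters whose backing storage
    # currently holds zero bytes (e.g. a freed shard on this rank).
    return p.untyped_storage().size() != 0

full = torch.nn.Parameter(torch.randn(8))
freed = torch.nn.Parameter(torch.randn(8))
freed.data = torch.empty(0)  # emulate a parameter whose storage was released

print(has_local_storage(full))   # True
print(has_local_storage(freed))  # False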
@@ -128,11 +128,11 @@ class FSDPadaptOptimizer(BaseOptimizer):
         # create gradient for fp32 params
         for group_idx in range(len(self.param_groups)):
             dtype = self._fp32_param_tensor_groups[group_idx][0].dtype
-            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.storage().size() != 0]
+            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
             grad_fp32 = [p.grad.to(dtype) for p in fp16_params]

             device = self._fp32_param_tensor_groups[group_idx][0].device
-            nonzero_fp32 = [p for p in self._fp32_param_tensor_groups[group_idx] if p.storage().size() != 0]
+            nonzero_fp32 = [p for p in self._fp32_param_tensor_groups[group_idx] if p.untyped_storage().size() != 0]
             for p, g in zip(nonzero_fp32, grad_fp32):
                 p.grad = g.to(device)
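The hunk above is the usual mixed-precision hand-off: each surviving fp16 gradient is cast to the fp32 master dtype and attached to the matching fp32 tensor so the optimizer can step in full precision. A minimal sketch of that pattern in isolation; the toy fp16_params/fp32_master lists stand in for the optimizer's internal groups.

import torch

# Toy stand-ins for self._fp16_param_groups / self._fp32_param_tensor_groups.
fp16_params = [torch.nn.Parameter(torch.randn(4).half()) for _ in range(2)]
fp32_master = [p.detach().clone().float() for p in fp16_params]

for p in fp16_params:
    p.grad = torch.randn(4).half()  # pretend backward() has already run

# Cast each fp16 gradient to the master dtype/device and attach it.
for master, p in zip(fp32_master, fp16_params):
    master.grad = p.grad.to(dtype=master.dtype, device=master.device)

print(fp32_master[0].grad.dtype)  # torch.float32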
@@ -143,8 +143,10 @@ class FSDPadaptOptimizer(BaseOptimizer):
         self.zero_grad()

         for group_idx in range(len(self._fp16_param_groups)):
-            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.storage().size() != 0]
-            fp32_tensor_params = [p for p in self._fp32_param_tensor_groups[group_idx] if p.storage().size() != 0]
+            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
+            fp32_tensor_params = [
+                p for p in self._fp32_param_tensor_groups[group_idx] if p.untyped_storage().size() != 0
+            ]
             # release fp32 grad
             release_param_grad(fp32_tensor_params)
             # update fp16 param
@@ -177,7 +179,7 @@ class FSDPadaptOptimizer(BaseOptimizer):

         for group_id, param in self._fp32_param_tensor_groups.items():
             for p in param:
-                if p.storage().size() != 0:
+                if p.untyped_storage().size() != 0:
                     p.grad.data.mul_(1.0 / combined_scale_groups[group_id])

     def state_dict(self):
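The loop above is the gradient-unscaling step of loss scaling: after a loss-scaled backward pass, each fp32 gradient that still has live storage is multiplied by 1/combined_scale to recover the true gradient before clipping and the optimizer step. A tiny sketch of the arithmetic; combined_scale here is an illustrative value, not InternLM's actual schedule.

import torch

combined_scale = 1024.0                 # illustrative loss scale
p = torch.zeros(4, requires_grad=True)
p.grad = torch.full((4,), 2048.0)       # gradient produced under the scaled loss
p.grad.data.mul_(1.0 / combined_scale)  # unscale in place
print(p.grad)                           # tensor([2., 2., 2., 2.])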
internlm/train/training_internlm.py

@@ -108,15 +108,6 @@ def initialize_model():
 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     RMSNorm = try_import_RMSNorm()
     if gpc.config.parallel.use_fsdp:
-        # pre-save info for tensor parallel
-        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
-            tp_dict = dict()
-            for name, param in model.named_parameters():
-                if hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL):
-                    tp_dict.update({name.replace("model.", ""): True})
-                else:
-                    tp_dict.update({name.replace("model.", ""): False})
-
         # set wrap_policy for fsdp wrap
         transformer_wrap_policy = functools.partial(
             transformer_auto_wrap_policy,
@@ -133,15 +124,9 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
             forward_prefetch=True,
             backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
             limit_all_gathers=True,
-            use_orig_params=False,
+            use_orig_params=True,
         )

-        # re-set attribute for fsdp module with tensor parallel
-        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
-            for (name, param), pre in zip(model.named_parameters(), tp_dict):
-                if pre in name and tp_dict[pre]:
-                    setattr(param, IS_TENSOR_PARALLEL, True)
-
     return model
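These two hunks carry the titled change: with use_orig_params=True, FSDP exposes the module's original parameters through named_parameters() instead of only its internal FlatParameters, which is presumably why the manual tp_dict bookkeeping and the IS_TENSOR_PARALLEL re-tagging are no longer needed. Below is a minimal single-process sketch of wrapping with that flag; the toy Block module, the gloo/CPU process group, and the port are stand-ins, not InternLM's actual wrap_FSDP_model.

import functools
import os

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import BackwardPrefetch, FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# Single-process process group so FSDP can be constructed on CPU.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

class Block(nn.Module):  # stand-in for a transformer layer
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)

    def forward(self, x):
        return self.proj(x)

model = nn.Sequential(Block(), Block())
wrap_policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={Block})

model = FSDP(
    model,
    auto_wrap_policy=wrap_policy,
    forward_prefetch=True,
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    limit_all_gathers=True,
    use_orig_params=True,  # the flag flipped in this commit
)

# With use_orig_params=True, named_parameters() yields the original
# nn.Parameter objects rather than FSDP's internal FlatParameters.
print([name for name, _ in model.named_parameters()][:2])

dist.destroy_process_group()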