fix(internlm/train/training_internlm.py): remove set IS_TENSOR_PARALLEL attr

pull/293/head
huangting4201 2023-10-08 20:22:40 +08:00
parent 1b71b19e23
commit 2e94870967
2 changed files with 10 additions and 23 deletions
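
The diff swaps the deprecated Tensor.storage() accessor for Tensor.untyped_storage() wherever the optimizer checks whether a parameter's local FSDP shard is materialized, and flips the FSDP wrapper to use_orig_params=True. A minimal sketch of the emptiness check on recent PyTorch (the two parameters below are illustrative, not taken from the repo):

import torch

# untyped_storage().size() returns the number of bytes backing the tensor,
# so a shard that FSDP has freed (or never materialized on this rank)
# reports 0 and can be filtered out before computing norms or stepping.
full = torch.nn.Parameter(torch.randn(4, 4))   # shard owned by this rank
empty = torch.nn.Parameter(torch.empty(0))     # stand-in for a freed shard

params = [p for p in (full, empty) if p.untyped_storage().size() != 0]
print(len(params))  # 1 -> only the locally materialized parameter survives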

@@ -78,8 +78,8 @@ class FSDPadaptOptimizer(BaseOptimizer):
         loss.backward(retain_graph=retain_graph)
 
     def _compute_norm_with_fsdp_flatten(self, group_id):
-        params = [p for p in self._fp16_param_groups[group_id] if p.storage().size() != 0]
-        gradients = [p.grad for p in params if p.storage().size() != 0]
+        params = [p for p in self._fp16_param_groups[group_id] if p.untyped_storage().size() != 0]
+        gradients = [p.grad for p in params if p.untyped_storage().size() != 0]
         norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)
 
         return norm_group
@@ -128,11 +128,11 @@ class FSDPadaptOptimizer(BaseOptimizer):
 
         # create gradient for fp32 params
         for group_idx in range(len(self.param_groups)):
            dtype = self._fp32_param_tensor_groups[group_idx][0].dtype
-            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.storage().size() != 0]
+            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
             grad_fp32 = [p.grad.to(dtype) for p in fp16_params]
 
             device = self._fp32_param_tensor_groups[group_idx][0].device
-            nonzero_fp32 = [p for p in self._fp32_param_tensor_groups[group_idx] if p.storage().size() != 0]
+            nonzero_fp32 = [p for p in self._fp32_param_tensor_groups[group_idx] if p.untyped_storage().size() != 0]
             for p, g in zip(nonzero_fp32, grad_fp32):
                 p.grad = g.to(device)
@@ -143,8 +143,10 @@ class FSDPadaptOptimizer(BaseOptimizer):
         self.zero_grad()
 
         for group_idx in range(len(self._fp16_param_groups)):
-            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.storage().size() != 0]
-            fp32_tensor_params = [p for p in self._fp32_param_tensor_groups[group_idx] if p.storage().size() != 0]
+            fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
+            fp32_tensor_params = [
+                p for p in self._fp32_param_tensor_groups[group_idx] if p.untyped_storage().size() != 0
+            ]
             # release fp32 grad
             release_param_grad(fp32_tensor_params)
             # update fp16 param
@@ -177,7 +179,7 @@ class FSDPadaptOptimizer(BaseOptimizer):
 
         for group_id, param in self._fp32_param_tensor_groups.items():
             for p in param:
-                if p.storage().size() != 0:
+                if p.untyped_storage().size() != 0:
                     p.grad.data.mul_(1.0 / combined_scale_groups[group_id])
 
     def state_dict(self):
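
The hunks above all follow the same pattern: drop parameters whose local shard is empty, then mirror the fp16 gradients onto their fp32 master copies before the step. A condensed, hypothetical stand-in for that bookkeeping (not the actual FSDPadaptOptimizer code; it assumes, as the diff does, that the two filtered lists line up one-to-one):

def attach_fp32_grads(fp16_params, fp32_params):
    # Skip parameters whose local shard FSDP has freed (storage size 0),
    # then hand each surviving fp16 grad to its fp32 master copy.
    live_fp16 = [p for p in fp16_params if p.untyped_storage().size() != 0]
    live_fp32 = [p for p in fp32_params if p.untyped_storage().size() != 0]
    for p32, p16 in zip(live_fp32, live_fp16):
        p32.grad = p16.grad.to(dtype=p32.dtype, device=p32.device)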

@@ -108,15 +108,6 @@ def initialize_model():
 def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
     RMSNorm = try_import_RMSNorm()
     if gpc.config.parallel.use_fsdp:
-        # pre-save info for tensor parallel
-        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
-            tp_dict = dict()
-            for name, param in model.named_parameters():
-                if hasattr(param, IS_TENSOR_PARALLEL) and getattr(param, IS_TENSOR_PARALLEL):
-                    tp_dict.update({name.replace("model.", ""): True})
-                else:
-                    tp_dict.update({name.replace("model.", ""): False})
-
         # set wrap_policy for fsdp wrap
         transformer_wrap_policy = functools.partial(
             transformer_auto_wrap_policy,
@@ -133,15 +124,9 @@ def wrap_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
             forward_prefetch=True,
             backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
             limit_all_gathers=True,
-            use_orig_params=False,
+            use_orig_params=True,
         )
 
-        # re-set attribute for fsdp module with tensor parallel
-        if gpc.get_world_size(ParallelMode.TENSOR) > 1:
-            for (name, param), pre in zip(model.named_parameters(), tp_dict):
-                if pre in name and tp_dict[pre]:
-                    setattr(param, IS_TENSOR_PARALLEL, True)
-
     return model
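
With use_orig_params=True, FSDP exposes the original nn.Parameter objects through named_parameters() rather than only its flattened FlatParameters, so attributes tagged on the parameters before wrapping (such as the tensor-parallel flag) stay visible and the removed re-tagging block is no longer needed. A minimal sketch of wrapping with this flag, assuming an already initialized process group (the helper name and block_cls argument are placeholders, not the project's real config):

import functools

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def wrap_with_orig_params(model: nn.Module, block_cls: type) -> FSDP:
    # Auto-wrap transformer blocks of type `block_cls`; use_orig_params=True
    # keeps the original parameters (and any attributes set on them) reachable
    # via wrapped.named_parameters() after sharding.
    policy = functools.partial(transformer_auto_wrap_policy, transformer_layer_cls={block_cls})
    return FSDP(
        model,
        auto_wrap_policy=policy,
        use_orig_params=True,
    )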