[checkpointio] gather the tensor before unpadding it if the tensor is both padded and distributed (#6168)

feature/dist-ckp-io
Lemon Qin 2025-01-21 10:23:15 +08:00 committed by GitHub
parent 5b094a836b
commit 97e60cbbcb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 2 additions and 2 deletions

View File

@ -107,9 +107,9 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if param is None:
continue
# Gather tensor pieces when using tensor parallel.
if is_padded_tensor(param):
param = to_unpadded_tensor(param)
param_ = gather_distributed_param(param, keep_vars=False)
if is_padded_tensor(param_):
param_ = to_unpadded_tensor(param_)
if pinned_state_dicts is not None:
if (prefix + name) not in pinned_state_dicts:
pinned_state_dicts[prefix + name] = torch.empty_like(param_, pin_memory=True, device="cpu")