[shardformer] Fix serialization error with Tensor Parallel state saving (#5018)

* Fix serialization error with Tensor Parallel state saving

* Refactor state_dict CPU transfer using tree_map
Jun Gao 2023-11-09 17:00:25 +08:00 committed by GitHub
parent 724441279b
commit a4489384d5
1 changed file with 6 additions and 3 deletions

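A minimal standalone sketch (not part of the commit) of the tree_map-based CPU transfer introduced in the diff below; note that torch.utils._pytree is a private PyTorch API, so the import location may change between releases:

import torch
from torch.utils._pytree import tree_map

state_dict = {
    "weight": torch.randn(4, 4),  # would be a CUDA or sharded tensor in practice
    "opt": {"step": 1, "exp_avg": torch.randn(4, 4)},
}

# tree_map recurses through nested containers (dicts, lists, tuples) and
# applies the function to every leaf; non-tensor leaves such as the step
# counter pass through unchanged.
state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)

assert state_dict_cpu["opt"]["exp_avg"].device.type == "cpu"
assert state_dict_cpu["opt"]["step"] == 1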

@@ -11,6 +11,7 @@ import torch
 import torch.nn as nn
 from packaging.version import Version
 from torch.optim import Optimizer
+from torch.utils._pytree import tree_map

 from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
@@ -293,7 +294,6 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
 # ======================================
 # Helper functions for saving state dict
 # ======================================
-
 def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
     """
     Save state dict to checkpoint.
@@ -303,6 +303,9 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
         checkpoint_file_path (str): path to the checkpoint file.
         use_safetensors (bool): whether to use safetensors to save the checkpoint.
     """
+    # Move all tensors in the state_dict to CPU before saving to avoid serialization issues
+    state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)
+
     if use_safetensors:
         assert is_safetensors_available(), "safetensors is not available."
         assert checkpoint_file_path.endswith(
@@ -310,9 +313,9 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
         ), "safetensors only supports .safetensors suffix for checkpoint file."
         from safetensors.torch import save_file as safe_save_file

-        safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
+        safe_save_file(state_dict_cpu, checkpoint_file_path, metadata={"format": "pt"})
     else:
-        torch.save(state_dict, checkpoint_file_path)
+        torch.save(state_dict_cpu, checkpoint_file_path)


 def save_param_groups(state_dict: dict, group_file_path: str) -> None:
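
For context, a hedged usage sketch of the patched helper; the import path below is an assumption based on the file this commit touches and may differ across ColossalAI releases:

import torch.nn as nn
from colossalai.checkpoint_io.utils import save_state_dict  # assumed import path

model = nn.Linear(8, 8).cuda()  # parameters live on GPU before saving

# After this change, save_state_dict moves every tensor to CPU first, so
# GPU-resident (e.g. tensor-parallel) states no longer raise serialization
# errors when handed to safetensors or torch.save.
save_state_dict(model.state_dict(), "model.safetensors", use_safetensors=True)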