mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] Fix serialization error with Tensor Parallel state saving (#5018)
* Fix serialization error with Tensor Parallel state saving
* Refactor state_dict CPU transfer using tree_map
parent 724441279b
commit a4489384d5
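The core of the fix is a single pytree traversal: every tensor in the (possibly nested) state dict is copied to CPU before it is handed to the serializer, while non-tensor entries pass through unchanged. The standalone sketch below reproduces that pattern outside ColossalAI; the toy state dict and the output file name are illustrative only.

import torch
from torch.utils._pytree import tree_map

# Toy, nested state dict standing in for model/optimizer state
# (illustrative only; real Tensor Parallel state comes from ColossalAI).
state_dict = {
    "weight": torch.randn(4, 4, device="cuda" if torch.cuda.is_available() else "cpu"),
    "opt": {"step": 10, "exp_avg": [torch.zeros(4), torch.ones(4)]},
}

# Same transform as the patch: tensors go to CPU, everything else passes through.
state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)

torch.save(state_dict_cpu, "checkpoint.bin")  # hypothetical output path

Because tree_map recurses through dicts, lists, and tuples, the same one-liner covers both flat model state dicts and nested optimizer state, which is presumably why the refactor prefers it over a hand-written per-key loop.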
@@ -11,6 +11,7 @@ import torch
 import torch.nn as nn
 from packaging.version import Version
 from torch.optim import Optimizer
+from torch.utils._pytree import tree_map
 
 from colossalai.tensor.d_tensor import (
     is_customized_distributed_tensor,
@@ -293,7 +294,6 @@ def shard_optimizer_checkpoint(state_dict: dict, max_shard_size: int = 1024) ->
 # Helper functions for saving state dict
 # ======================================
 
-
 def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors: bool) -> None:
     """
     Save state dict to checkpoint.
@@ -303,6 +303,9 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
         checkpoint_file_path (str): path to the checkpoint file.
         use_safetensors (bool): whether to use safetensors to save the checkpoint.
     """
+    # Move all tensors in the state_dict to CPU before saving to avoid serialization issues
+    state_dict_cpu = tree_map(lambda x: x.cpu() if torch.is_tensor(x) else x, state_dict)
+
     if use_safetensors:
         assert is_safetensors_available(), "safetensors is not available."
         assert checkpoint_file_path.endswith(
@@ -310,9 +313,9 @@ def save_state_dict(state_dict: dict, checkpoint_file_path: str, use_safetensors
         ), "safetensors only supports .safetensors suffix for checkpoint file."
         from safetensors.torch import save_file as safe_save_file
 
-        safe_save_file(state_dict, checkpoint_file_path, metadata={"format": "pt"})
+        safe_save_file(state_dict_cpu, checkpoint_file_path, metadata={"format": "pt"})
     else:
-        torch.save(state_dict, checkpoint_file_path)
+        torch.save(state_dict_cpu, checkpoint_file_path)
 
 
 def save_param_groups(state_dict: dict, group_file_path: str) -> None:
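For completeness, a hedged usage sketch of the patched helper follows. The import path colossalai.checkpoint_io.utils is an assumption, since the hunks above do not name the module, and the small Linear module merely stands in for a Tensor Parallel model whose gathered parameters may still live on GPU.

import torch
import torch.nn as nn

# Assumed import path; the diff above does not show the file name.
from colossalai.checkpoint_io.utils import save_state_dict

# Stand-in model; with this commit, any GPU tensors in its state dict are
# moved to CPU inside the helper, so the caller avoids the serialization error.
model = nn.Linear(8, 8)
if torch.cuda.is_available():
    model = model.cuda()

save_state_dict(model.state_dict(), "model.safetensors", use_safetensors=True)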