mirror of https://github.com/hpcaitech/ColossalAI
[optim] hotfix adam load (#6146)
* [optim] hotfix adam load
* [checkpointio] fix optimizer async io
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci
* [checkpointio] update test
* [checkpointio] update test

---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

pull/6148/head
parent 5caad13055
commit cf519dac6a
@@ -142,7 +142,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
             from colossalai.utils.safetensors import save_nested

             f_writer = AsyncFileWriter(fp=open(checkpoint, "wb"), n_entries=self.N_WRITE_ENTRIES, backend="pthread")
-            save_nested(f_writer, state_dict["state"], {"param_groups": state_dict["param_groups"]})
+            save_nested(f_writer, state_dict)
             self.async_writers.append(f_writer)
         else:
             save_state_dict(state_dict, checkpoint, use_safetensors=False)
@@ -81,6 +81,14 @@ class CPUAdam(NVMeOptimizer):
         # if you find yourself stuck here, make sure that you install colossalai with BUILD_EXT=1 specification
         self.cpu_adam_op = cpu_adam.CPUAdamOptimizer(lr, betas[0], betas[1], eps, weight_decay, adamw_mode)

+    def load_state_dict(self, state_dict):
+        super().load_state_dict(state_dict)
+        for group in self.param_groups:
+            for p in group["params"]:
+                state = self.state[p]
+                if "step" in state and isinstance(state["step"], torch.Tensor):
+                    state["step"] = int(state["step"].item())
+
     def torch_adam_update(
         self,
         data,
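For context: after loading a safetensors checkpoint, `step` comes back as a 0-dim tensor, and the new `load_state_dict` override casts it back, presumably because the fused CPU Adam kernel consumes a plain Python int. A minimal standalone sketch of that cast (illustrative only, not part of this commit):

import torch

# Hypothetical state entry as restored from a checkpoint: "step" is a 0-dim tensor.
state = {"step": torch.tensor(1.0), "exp_avg": torch.zeros(4), "exp_avg_sq": torch.zeros(4)}
if "step" in state and isinstance(state["step"], torch.Tensor):
    state["step"] = int(state["step"].item())  # same cast as the override above
assert state["step"] == 1 and isinstance(state["step"], int)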
@@ -1,4 +1,4 @@
-from typing import Any, List, OrderedDict, Tuple
+from typing import Any, List, OrderedDict

 import torch
 import torch.distributed as dist
@@ -78,9 +78,7 @@ def check_state_dict_equal(
                 v1 = v1.to(v2.dtype)
             assert_close_loose(v1, v2)
         else:
-            if isinstance(v1, Tuple) and not isinstance(v2, Tuple):
-                v2 = tuple(v2)
-            assert v1 == v2, f"{v1} not equals to {v2}. {type(v1)}, {type(v2)}"
+            assert v1 == v2, f"{v1} not equals to {v2}"


 def check_state_dict_equal_pytree(d1: OrderedDict, d2: OrderedDict, ignore_device: bool = True):
@@ -1,6 +1,5 @@
 # a python safetensors serializer modified from https://github.com/huggingface/safetensors/blob/41bd1acf38ad28ac559522d40596c6c802f79453/safetensors/src/tensor.rs#L214
 import json
-import warnings
 from dataclasses import asdict, dataclass
 from typing import Dict, List, Optional, Tuple

@@ -12,6 +11,26 @@ try:
 except ModuleNotFoundError:
     raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")
 _TYPES_INV = {v: k for k, v in _TYPES.items()}
+
+import io
+
+from torch.distributed.distributed_c10d import _pickler, _unpickler
+
+
+def _object_to_tensor(obj, device):
+    f = io.BytesIO()
+    _pickler(f).dump(obj)
+    byte_storage = torch.ByteStorage._from_buffer(f.getvalue())  # type: ignore[attr-defined]
+    # Do not replace `torch.ByteTensor` or `torch.LongTensor` with torch.tensor and specifying dtype.
+    # Otherwise, it will casue 100X slowdown.
+    # See: https://github.com/pytorch/pytorch/issues/65696
+    byte_tensor = torch.ByteTensor(byte_storage).to(device)
+    return byte_tensor
+
+
+def _tensor_to_object(tensor, tensor_size):
+    tensor = tensor.cpu()
+    buf = tensor.numpy().tobytes()[:tensor_size]
+    return _unpickler(io.BytesIO(buf)).load()
+
+
 @dataclass
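The two helpers above reuse torch.distributed's object-to-tensor trick so that arbitrary picklable values can ride inside a tensors-only safetensors file. A standalone sketch of the same round trip using the public pickle module (the commit itself uses the internal _pickler/_unpickler):

import io
import pickle

import torch

obj = {"lr": 0.001, "betas": (0.9, 0.999)}  # a non-tensor leaf, e.g. part of a param group
buf = io.BytesIO()
pickle.Pickler(buf).dump(obj)  # serialize the object to raw bytes
byte_tensor = torch.frombuffer(bytearray(buf.getvalue()), dtype=torch.uint8)  # view the bytes as a uint8 tensor
restored = pickle.Unpickler(io.BytesIO(byte_tensor.numpy().tobytes())).load()  # reverse the conversion
assert restored == obj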
@@ -28,49 +47,68 @@ class PreparedData:
     offset: int


-def flatten_dict(nested_dict, parent_key="", separator="^"):
-    """
-    Flatten a nested dictionary, generating a flattened dictionary where the keys are joined by the specified separator.
-
-    nested_dict: The input nested dictionary.
-    parent_key: The parent key currently being processed.
-    separator: The separator used to join keys, default is '_', but can be customized to another symbol. :return: A flattened dictionary."
-    """
-    items = []
-    for k, v in nested_dict.items():
-        new_key = f"{parent_key}{separator}{k}" if parent_key else str(k)
-        if isinstance(v, dict):
-            items.extend(flatten_dict(v, new_key, separator).items())
-        else:
-            v = torch.tensor(v, dtype=torch.float16) if not isinstance(v, torch.Tensor) else v
-            items.append((new_key, v))
-
-    return dict(items)
-
-
-def unflatten_dict(flattened_dict, separator="^"):
-    """
-    Restore a flattened dictionary back to a multi-level nested dictionary.
-
-    flattened_dict: The flattened dictionary.
-    separator: The separator used during flattening, default is '_', but can be customized to another symbol. :return: The restored nested dictionary.
-    """
-    nested_dict = {}
-    for key, value in flattened_dict.items():
-        keys = key.split(separator)
-        try:
-            keys[0] = int(keys[0])
-        except ValueError:
-            warnings.warn(f"{key[0]} can't convert to integer")
-        d = nested_dict
-        for part in keys[:-1]:
-            if part not in d:
-                d[part] = {}
-            d = d[part]
-        assert isinstance(value, torch.Tensor)
-        d[keys[-1]] = value
-
-    return nested_dict
+def _cast_to_tensor(obj):
+    if isinstance(obj, torch.Tensor):
+        return obj
+    return _object_to_tensor(obj, "cpu")
+
+
+def _cast_to_object(tensor: torch.Tensor):
+    return _tensor_to_object(tensor, tensor.numel() * tensor.element_size())
+
+
+def _flatten_optim_state_dict(state_dict: dict, seperator: str = ".") -> Tuple[dict, Optional[dict]]:
+    flat_dict = {}
+    non_tensor_keys = []
+    if "state" in state_dict:
+        # 3-level dict
+        states = state_dict["state"]
+    else:
+        # 2-level dict, usually for optimizer state dict shard
+        states = state_dict
+
+    for idx, d in states.items():
+        for k, v in d.items():
+            nested_key = f"state{seperator}{idx}{seperator}{k}"
+            if not isinstance(v, torch.Tensor):
+                non_tensor_keys.append(nested_key)
+            flat_dict[nested_key] = _cast_to_tensor(v)
+    if "param_groups" in state_dict:
+        flat_dict["param_groups"] = _cast_to_tensor(state_dict["param_groups"])
+        non_tensor_keys.append("param_groups")
+    if len(non_tensor_keys) > 0:
+        metadata = {"non_tensor_keys": non_tensor_keys}
+    else:
+        metadata = None
+    return flat_dict, metadata
+
+
+def _unflatten_optim_state_dict(flat_dict: dict, metadata: Optional[dict] = None, seperator: str = "."):
+    state_dict = {}
+    if metadata is not None:
+        non_tensor_keys = json.loads(metadata["non_tensor_keys"])
+    else:
+        non_tensor_keys = []
+    flat_dict = {k: _cast_to_object(v) if k in non_tensor_keys else v for k, v in flat_dict.items()}
+    if "param_groups" in flat_dict:
+        # 3-level dict
+        state_dict["param_groups"] = flat_dict.pop("param_groups")
+        state_dict["state"] = {}
+        states = state_dict["state"]
+    else:
+        # 2-level dict, usually for optimizer state dict shard
+        states = state_dict
+
+    for k, v in flat_dict.items():
+        parts = k.split(seperator)
+        assert len(parts) == 3 and parts[0] == "state"
+        idx = int(parts[1])
+        key = parts[2]
+        if idx not in states:
+            states[idx] = {}
+        states[idx][key] = v
+
+    return state_dict


 def prepare(
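In short, `_flatten_optim_state_dict` turns a two- or three-level optimizer state dict into flat `state.<param index>.<field>` keys, pickles non-tensor leaves via `_cast_to_tensor`, and records their keys in the metadata so `_unflatten_optim_state_dict` can restore them. A small sketch of the resulting key layout, assuming an Adam-style state dict (illustrative only, not part of this commit):

import torch

optim_state = {
    "state": {0: {"step": 1, "exp_avg": torch.zeros(2), "exp_avg_sq": torch.zeros(2)}},
    "param_groups": [{"lr": 1e-3, "params": [0]}],
}
flat_keys = [f"state.{idx}.{k}" for idx, entry in optim_state["state"].items() for k in entry]
flat_keys.append("param_groups")
print(flat_keys)  # ['state.0.step', 'state.0.exp_avg', 'state.0.exp_avg_sq', 'param_groups']
# "state.0.step" and "param_groups" are non-tensor leaves, so they would be pickled
# into byte tensors and listed under metadata["non_tensor_keys"].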
@@ -124,10 +162,8 @@ def save(
         f_writer.write_raw(tensor, tensor.data_ptr(), tensor.numel() * tensor.element_size(), f_writer.offset)


-def save_nested(
-    f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor], metadata: Optional[Dict[str, str]] = None
-) -> None:
-    flatten_data = flatten_dict(state_dict)
+def save_nested(f_writer: AsyncFileWriter, state_dict: Dict[str, torch.Tensor]) -> None:
+    flatten_data, metadata = _flatten_optim_state_dict(state_dict)
     save(f_writer, flatten_data, metadata)

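With the new signature, callers no longer pass metadata; `save_nested` derives it from the state dict itself. A hedged usage sketch mirroring the updated test further below (assumes tensornvme and colossalai are installed; the toy torch.optim.Adam instance and the output path are only illustrative):

import torch
from tensornvme.async_file_io import AsyncFileWriter

from colossalai.utils.safetensors import load_flat, save_nested

# Toy optimizer, just to produce a realistic Adam state dict.
p = torch.nn.Parameter(torch.zeros(4))
optim = torch.optim.Adam([p], lr=1e-3)
p.grad = torch.ones_like(p)
optim.step()

f_writer = AsyncFileWriter(fp=open("optim.safetensors", "wb"), n_entries=191, backend="pthread")
save_nested(f_writer, optim.state_dict())  # metadata is now derived internally
f_writer.sync_before_step()
f_writer.synchronize()
f_writer.fp.close()

restored = load_flat("optim.safetensors")  # {"state": {...}, "param_groups": [...]}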
@@ -154,10 +190,5 @@ def load_flat(checkpoint_path):
     with safe_open(checkpoint_path, framework="pt") as f:
         metadata = f.metadata()
     state_dict_load = load_file(checkpoint_path)
-    state_dict = unflatten_dict(state_dict_load)
-    if metadata is None:
-        return state_dict
-    metadata = dict(map(lambda item: (item[0], json.loads(item[1])), metadata.items()))
-    combined_state_dict = {"state": state_dict}
-    combined_state_dict.update(metadata)
-    return combined_state_dict
+    state_dict = _unflatten_optim_state_dict(state_dict_load, metadata)
+    return state_dict
@@ -1,9 +1,9 @@
 import tempfile
-from copy import deepcopy

 import torch
+from safetensors.torch import load_file

-from colossalai.utils.safetensors import load_flat, save_nested
+from colossalai.utils.safetensors import load_flat, move_and_save, save, save_nested

 try:
     from tensornvme.async_file_io import AsyncFileWriter
@@ -11,17 +11,29 @@ except ModuleNotFoundError:
     raise ModuleNotFoundError("Please install tensornvme to use NVMeOptimizer")

 from colossalai.testing import check_state_dict_equal
+from colossalai.utils import get_current_device


 def test_save_load():
     with tempfile.TemporaryDirectory() as tempdir:
         optimizer_state_dict = {
-            0: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-            1: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-            2: {"step": torch.tensor(1.0), "exp_avg": torch.rand((1024, 1024)), "exp_avg_sq": torch.rand((1024, 1024))},
-        }
-        # group_dict = {"param_groups": [0, 1, 2]}
-        group_dict = {
+            "state": {
+                0: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+                1: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+                2: {
+                    "step": torch.tensor(1.0),
+                    "exp_avg": torch.rand((1024, 1024)),
+                    "exp_avg_sq": torch.rand((1024, 1024)),
+                },
+            },
             "param_groups": [
                 {
                     "lr": 0.001,
@@ -94,22 +106,26 @@ def test_save_load():
                         61,
                     ],
                 }
-            ]
-        }
-        metadata = deepcopy(group_dict)
+            ],
+        }
         optimizer_saved_path = f"{tempdir}/save_optimizer.safetensors"
         f_writer = AsyncFileWriter(fp=open(optimizer_saved_path, "wb"), n_entries=191, backend="pthread")
-        save_nested(f_writer, optimizer_state_dict, metadata)
+        save_nested(f_writer, optimizer_state_dict)
         f_writer.sync_before_step()
         f_writer.synchronize()
         f_writer.fp.close()

         load_state_dict = load_flat(optimizer_saved_path)
-        state_dict = load_state_dict["state"]
-        group = {"param_groups": load_state_dict["param_groups"]}
-        check_state_dict_equal(optimizer_state_dict, state_dict)
-        check_state_dict_equal(group_dict, group)
+        check_state_dict_equal(load_state_dict, optimizer_state_dict)
+
+        optimizer_shard_saved_path = f"{tempdir}/save_optimizer_shard.safetensors"
+        f_writer = AsyncFileWriter(fp=open(optimizer_shard_saved_path, "wb"), n_entries=191, backend="pthread")
+        save_nested(f_writer, optimizer_state_dict["state"])
+        f_writer.sync_before_step()
+        f_writer.synchronize()
+        f_writer.fp.close()
+        load_state_dict_shard = load_flat(optimizer_shard_saved_path)
+        check_state_dict_equal(load_state_dict_shard, optimizer_state_dict["state"])

         model_state_dict = {
             "module.weight0": torch.rand((1024, 1024)),
|
||||||
}
|
}
|
||||||
model_saved_path = f"{tempdir}/save_model.safetensors"
|
model_saved_path = f"{tempdir}/save_model.safetensors"
|
||||||
f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
|
f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
|
||||||
save_nested(f_writer, model_state_dict)
|
save(f_writer, model_state_dict)
|
||||||
f_writer.sync_before_step()
|
f_writer.sync_before_step()
|
||||||
f_writer.synchronize()
|
f_writer.synchronize()
|
||||||
f_writer.fp.close()
|
f_writer.fp.close()
|
||||||
|
load_state_dict = load_file(model_saved_path)
|
||||||
load_state_dict = load_flat(model_saved_path)
|
check_state_dict_equal(model_state_dict, load_state_dict)
|
||||||
|
|
||||||
|
model_state_dict_cuda = {k: v.to(get_current_device()) for k, v in model_state_dict.items()}
|
||||||
|
model_state_pinned = {k: v.pin_memory() for k, v in model_state_dict.items()}
|
||||||
|
model_saved_path = f"{tempdir}/save_model_cuda.safetensors"
|
||||||
|
f_writer = AsyncFileWriter(fp=open(model_saved_path, "wb"), n_entries=191, backend="pthread")
|
||||||
|
move_and_save(f_writer, model_state_dict_cuda, model_state_pinned)
|
||||||
|
f_writer.sync_before_step()
|
||||||
|
f_writer.synchronize()
|
||||||
|
f_writer.fp.close()
|
||||||
|
load_state_dict = load_file(model_saved_path)
|
||||||
check_state_dict_equal(model_state_dict, load_state_dict)
|
check_state_dict_equal(model_state_dict, load_state_dict)
|