mirror of https://github.com/hpcaitech/ColossalAI
[plugin] a workaround for zero plugins' optimizer checkpoint (#3780)
* [test] refactor torch ddp checkpoint test * [plugin] update low level zero optim checkpoint * [plugin] update gemini optim checkpointpull/4788/head
parent
60e6a154bc
commit
3c07a2846e
|
@ -52,8 +52,16 @@ class GeminiCheckpointIO(GeneralCheckpointIO):
|
|||
Save optimizer to checkpoint but only on master process.
|
||||
"""
|
||||
# TODO(ver217): optimizer state dict is sharded
|
||||
warnings.warn('GeminiPlugin does not support save full optimizer checkpoint now. Save it on every process.')
|
||||
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
|
||||
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
|
||||
|
||||
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
|
||||
warnings.warn(
|
||||
'GeminiPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
|
||||
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
|
||||
super().load_optimizer(optimizer, checkpoint)
|
||||
|
||||
def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
|
||||
"""
|
||||
Save model to checkpoint but only on master process.
|
||||
|
|
|
@ -9,7 +9,7 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|||
from torch.utils._pytree import tree_map
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIO
|
||||
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
|
||||
|
@ -32,8 +32,17 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
|||
"""
|
||||
Save optimizer to checkpoint but only on master process.
|
||||
"""
|
||||
# TODO(ver217): optimizer state dict is sharded
|
||||
super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)
|
||||
# TODO(ver217): optimizer state dict is sharded, and cannot get full state dict now
|
||||
warnings.warn(
|
||||
'LowLevelZeroPlugin does not support save full optimizer checkpoint now. Save it on every process.')
|
||||
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
|
||||
GeneralCheckpointIO.save_unsharded_optimizer(self, optimizer, checkpoint, gather_dtensor)
|
||||
|
||||
def load_optimizer(self, optimizer: Optimizer, checkpoint: str):
|
||||
warnings.warn(
|
||||
'LowLevelZeroPlugin can only load optimizer checkpoint saved by itself with the same number of processes.')
|
||||
checkpoint = f'{checkpoint}.rank{self.coordinator.rank}'
|
||||
super().load_optimizer(optimizer, checkpoint)
|
||||
|
||||
|
||||
class LowLevelZeroModel(ModelWrapper):
|
||||
|
|
|
@ -1,87 +1,95 @@
|
|||
import tempfile
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from utils import shared_tempdir
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import GeminiPlugin
|
||||
from colossalai.booster.plugin.gemini_plugin import GeminiCheckpointIO
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.testing import check_state_dict_equal, parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.zero import ColoInitContext, ZeroDDP
|
||||
from colossalai.zero import ZeroDDP
|
||||
from colossalai.zero.gemini.chunk import ChunkManager, search_chunk_configuration
|
||||
from colossalai.zero.gemini.gemini_mgr import GeminiManager
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu'])
|
||||
@parameterize('model_name', ['bert'])
|
||||
@parameterize('use_safetensors', [True, False])
|
||||
@parameterize('model_name', ['transformers_bert_for_sequence_classification'])
|
||||
@parameterize('use_safetensors', [False, True])
|
||||
def exam_state_dict_with_origin(placement_policy, model_name, use_safetensors: bool):
|
||||
from transformers import BertForSequenceClassification
|
||||
(model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
bert_model = model_fn()
|
||||
|
||||
model_ckpt_dir = tempfile.TemporaryDirectory()
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, *_ = get_components_func()
|
||||
with ColoInitContext(device=(get_current_device())):
|
||||
bert_model = model_builder()
|
||||
bert_model.config.save_pretrained(save_directory=(model_ckpt_dir.name))
|
||||
with shared_tempdir() as tempdir:
|
||||
pretrained_path = os.path.join(tempdir, 'pretrained')
|
||||
bert_model.config.save_pretrained(save_directory=pretrained_path)
|
||||
|
||||
config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100)
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
bert_model = ZeroDDP(bert_model, gemini_manager)
|
||||
bert_model.train()
|
||||
# TODO(ver217): use boost api
|
||||
config_dict, *_ = search_chunk_configuration(bert_model, search_range_mb=1, search_interval_byte=100)
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
bert_model = ZeroDDP(bert_model, gemini_manager)
|
||||
bert_model.train()
|
||||
|
||||
ckpt_io = GeminiCheckpointIO()
|
||||
if ckpt_io.coordinator.is_master():
|
||||
ckpt_io = GeminiCheckpointIO()
|
||||
model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
|
||||
ckpt_io.save_model(bert_model, (model_ckpt_dir.name),
|
||||
ckpt_io.save_model(bert_model, (pretrained_path),
|
||||
True,
|
||||
True,
|
||||
'', (model_size / 3),
|
||||
use_safetensors=use_safetensors)
|
||||
new_bert_model = BertForSequenceClassification.from_pretrained(model_ckpt_dir.name)
|
||||
check_state_dict_equal(bert_model.state_dict(only_rank_0=True, dtype=(torch.float32)),
|
||||
dist.barrier()
|
||||
new_bert_model = BertForSequenceClassification.from_pretrained(pretrained_path)
|
||||
check_state_dict_equal(bert_model.state_dict(only_rank_0=False, dtype=torch.float32),
|
||||
new_bert_model.state_dict(), False)
|
||||
model_ckpt_dir.cleanup()
|
||||
|
||||
|
||||
@parameterize('placement_policy', ['cuda', 'cpu'])
|
||||
@parameterize('model_name', ['gpt2', 'bert'])
|
||||
@parameterize('use_safetensors', [True, False])
|
||||
def exam_state_dict(placement_policy, model_name: str, use_safetensors: bool):
|
||||
get_components_func = non_distributed_component_funcs.get_callable(model_name)
|
||||
model_builder, *_ = get_components_func()
|
||||
with ColoInitContext(device=(get_current_device())):
|
||||
model = model_builder()
|
||||
new_model = model_builder()
|
||||
config_dict, *_ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
|
||||
chunk_manager = ChunkManager(config_dict)
|
||||
gemini_manager = GeminiManager(placement_policy, chunk_manager)
|
||||
model = ZeroDDP(model, gemini_manager)
|
||||
@parameterize('shard', [True, False])
|
||||
@parameterize('model_name', ['transformers_gpt'])
|
||||
def exam_state_dict(placement_policy, shard: bool, model_name: str):
|
||||
(model_fn, data_gen_fn, output_transform_fn, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
|
||||
criterion = lambda x: x.mean()
|
||||
plugin = GeminiPlugin(placement_policy=placement_policy)
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
model.train()
|
||||
#new model
|
||||
new_config_dict, *_ = search_chunk_configuration(new_model, search_range_mb=1, search_interval_byte=100)
|
||||
new_chunk_manager = ChunkManager(new_config_dict)
|
||||
new_gemini_manager = GeminiManager(placement_policy, new_chunk_manager)
|
||||
new_model = ZeroDDP(new_model, new_gemini_manager)
|
||||
model = model_fn()
|
||||
new_model = model_fn()
|
||||
optimizer = HybridAdam(model.parameters(), lr=0.001)
|
||||
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
|
||||
new_optimizer = HybridAdam(new_model.parameters(), lr=0.001)
|
||||
new_model, new_optimizer, criterion, _, _ = booster.boost(new_model, new_optimizer, criterion)
|
||||
|
||||
model_ckpt_dir = tempfile.TemporaryDirectory()
|
||||
ckpt_io = GeminiCheckpointIO()
|
||||
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
|
||||
ckpt_io.save_model(model, (model_ckpt_dir.name),
|
||||
True,
|
||||
True,
|
||||
'epoch', (model_size / 3),
|
||||
use_safetensors=use_safetensors)
|
||||
data = data_gen_fn()
|
||||
data = {k: v.to('cuda') if torch.is_tensor(v) or 'Tensor' in v.__class__.__name__ else v for k, v in data.items()}
|
||||
output = model(**data)
|
||||
output = output_transform_fn(output)
|
||||
output_key = list(output.keys())[0]
|
||||
loss = criterion(output[output_key])
|
||||
|
||||
if ckpt_io.coordinator.is_master():
|
||||
ckpt_io.load_model(new_model, (model_ckpt_dir.name), strict=True)
|
||||
model_dict = model.state_dict(only_rank_0=True)
|
||||
new_model_dict = new_model.state_dict(only_rank_0=True)
|
||||
check_state_dict_equal(model_dict, new_model_dict, False)
|
||||
model_ckpt_dir.cleanup()
|
||||
booster.backward(loss, optimizer)
|
||||
optimizer.step()
|
||||
|
||||
with shared_tempdir() as tempdir:
|
||||
model_ckpt_path = f"{tempdir}/model"
|
||||
optimizer_ckpt_path = f"{tempdir}/optimizer"
|
||||
booster.save_model(model, model_ckpt_path)
|
||||
if not shard:
|
||||
# TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
|
||||
booster.save_optimizer(optimizer, optimizer_ckpt_path)
|
||||
dist.barrier()
|
||||
|
||||
booster.load_model(new_model, model_ckpt_path)
|
||||
check_state_dict_equal(model.unwrap().state_dict(only_rank_0=False),
|
||||
new_model.unwrap().state_dict(only_rank_0=False), False)
|
||||
if not shard:
|
||||
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
|
@ -92,7 +100,7 @@ def run_dist(rank, world_size, port):
|
|||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [4, 4])
|
||||
@pytest.mark.parametrize('world_size', [2])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_gemini_ckpIO(world_size):
|
||||
spawn(run_dist, world_size)
|
||||
|
|
|
@ -1,13 +1,11 @@
|
|||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torchvision.models import resnet18
|
||||
from utils import shared_tempdir
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import LowLevelZeroPlugin
|
||||
from colossalai.booster.plugin.low_level_zero_plugin import LowLevelZeroCheckpointIO
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.testing import (
|
||||
check_state_dict_equal,
|
||||
|
@ -20,7 +18,8 @@ from colossalai.testing import (
|
|||
|
||||
@clear_cache_before_run()
|
||||
@parameterize('stage', [2])
|
||||
def check_low_level_zero_checkpointIO(stage: int):
|
||||
@parameterize('shard', [True, False])
|
||||
def check_low_level_zero_checkpointIO(stage: int, shard: bool):
|
||||
plugin = LowLevelZeroPlugin(stage=stage, max_norm=1.0, initial_scale=32)
|
||||
booster = Booster(plugin=plugin)
|
||||
model = resnet18()
|
||||
|
@ -34,17 +33,25 @@ def check_low_level_zero_checkpointIO(stage: int):
|
|||
loss = criterion(output)
|
||||
booster.backward(loss, optimizer)
|
||||
optimizer.step()
|
||||
with shared_tempdir() as tempdir:
|
||||
model_ckpt_path = f"{tempdir}/model"
|
||||
optimizer_ckpt_path = f"{tempdir}/optimizer"
|
||||
# lr scheduler is tested in test_torch_ddp_checkpoint_io.py and low level zero does not change it, we can skip it here
|
||||
booster.save_model(model, model_ckpt_path, shard=shard)
|
||||
if not shard:
|
||||
# TODO(ver217): optimizer checkpointing is not supported for sharded checkpoint
|
||||
booster.save_optimizer(optimizer, optimizer_ckpt_path)
|
||||
dist.barrier()
|
||||
|
||||
optimizer_ckpt_tempfile = tempfile.NamedTemporaryFile()
|
||||
ckpt_io = LowLevelZeroCheckpointIO()
|
||||
ckpt_io.save_optimizer(optimizer, optimizer_ckpt_tempfile.name)
|
||||
new_model = resnet18()
|
||||
new_optimizer = HybridAdam((new_model.parameters()), lr=0.001)
|
||||
new_model, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)
|
||||
|
||||
new_model = resnet18()
|
||||
new_optimizer = HybridAdam((new_model.parameters()), lr=0.001)
|
||||
_, new_optimizer, _, _, _ = booster.boost(new_model, new_optimizer)
|
||||
if ckpt_io.coordinator.is_master():
|
||||
ckpt_io.load_optimizer(new_optimizer, optimizer_ckpt_tempfile.name)
|
||||
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
|
||||
booster.load_model(new_model, model_ckpt_path)
|
||||
check_state_dict_equal(model.state_dict(), new_model.state_dict(), False)
|
||||
if not shard:
|
||||
booster.load_optimizer(new_optimizer, optimizer_ckpt_path)
|
||||
check_state_dict_equal(optimizer.state_dict(), new_optimizer.state_dict(), False)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
import tempfile
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import SGD
|
||||
from torchvision.models import resnet18
|
||||
from utils import shared_tempdir
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
|
@ -35,11 +34,7 @@ def check_torch_ddp_checkpointIO(shard: bool):
|
|||
optimizer.step()
|
||||
scheduler.step()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tempdir:
|
||||
obj = [tempdir]
|
||||
dist.broadcast_object_list(obj, src=0)
|
||||
tempdir = obj[0] # use the same directory on all ranks
|
||||
|
||||
with shared_tempdir() as tempdir:
|
||||
model_ckpt_path = f"{tempdir}/model"
|
||||
optimizer_ckpt_path = f"{tempdir}/optimizer"
|
||||
lr_scheduler_ckpt_path = f"{tempdir}/lr_scheduler"
|
||||
|
@ -66,8 +61,6 @@ def check_torch_ddp_checkpointIO(shard: bool):
|
|||
booster.load_lr_scheduler(new_scheduler, lr_scheduler_ckpt_path)
|
||||
check_state_dict_equal(scheduler.state_dict(), new_scheduler.state_dict(), False)
|
||||
|
||||
dist.barrier()
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
colossalai.launch(config=(dict()), rank=rank, world_size=world_size, port=port, host='localhost')
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
import tempfile
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from typing import Iterator
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
@contextmanager
|
||||
def shared_tempdir() -> Iterator[str]:
|
||||
"""
|
||||
A temporary directory that is shared across all processes.
|
||||
"""
|
||||
ctx_fn = tempfile.TemporaryDirectory if dist.get_rank() == 0 else nullcontext
|
||||
with ctx_fn() as tempdir:
|
||||
try:
|
||||
obj = [tempdir]
|
||||
dist.broadcast_object_list(obj, src=0)
|
||||
tempdir = obj[0] # use the same directory on all ranks
|
||||
yield tempdir
|
||||
finally:
|
||||
dist.barrier()
|
Loading…
Reference in New Issue