From f38006ea83963c48ca138bf82631038fbc692510 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Wed, 6 Jul 2022 17:22:03 +0800 Subject: [PATCH] [checkpoint] checkpoint for ColoTensor Model (#1196) --- colossalai/utils/checkpoint/__init__.py | 3 + .../utils/checkpoint/module_checkpoint.py | 73 ++++++ colossalai/utils/model/colo_init_context.py | 6 +- tests/test_utils/test_colo_checkpoint.py | 211 ++++++++++++++++++ 4 files changed, 292 insertions(+), 1 deletion(-) create mode 100644 colossalai/utils/checkpoint/__init__.py create mode 100644 colossalai/utils/checkpoint/module_checkpoint.py create mode 100644 tests/test_utils/test_colo_checkpoint.py diff --git a/colossalai/utils/checkpoint/__init__.py b/colossalai/utils/checkpoint/__init__.py new file mode 100644 index 000000000..1795b4ce3 --- /dev/null +++ b/colossalai/utils/checkpoint/__init__.py @@ -0,0 +1,3 @@ +from .module_checkpoint import save_checkpoint, load_checkpoint + +__all__ = ['save_checkpoint', 'load_checkpoint'] diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/utils/checkpoint/module_checkpoint.py new file mode 100644 index 000000000..c4de1c5ea --- /dev/null +++ b/colossalai/utils/checkpoint/module_checkpoint.py @@ -0,0 +1,73 @@ +import torch +import torch.nn as nn +import torch.distributed as dist +import collections +from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR +from colossalai.utils.model.colo_init_context import colo_state_dict + +def save_checkpoint(dire, + epoch: int, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer = None, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + *args, + **kwargs): + """save_checkpoint + save a model, whose parameters are `ColoTensor`s. + Args: + dire (_type_): _description_ + epoch (int): _description_ + model (torch.nn.Module): _description_ + optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None. + lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None. + """ + model_state = { + 'epoch': epoch, + 'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict) + } + if dist.get_rank() == 0: + torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch)) + lr_scheduler_dict = lr_scheduler.state_dict() + lr_scheduler_dict['after_scheduler'] = lr_scheduler_dict['after_scheduler'].state_dict() + optim_state = { + 'epoch': epoch, + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler_dict + } + torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank())) + + + + +def load_checkpoint(dire, + epoch: int, + rank: int, + model: torch.nn.Module, + optimizer: torch.optim.Optimizer = None, + lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None, + *args, + **kwargs): + """load_checkpoint + load a model, whose parameters are `ColoTensor`s. + Args: + dire (_type_): _description_ + epoch (int): _description_ + rank (int): _description_ + model (torch.nn.Module): _description_ + optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None. + lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None. + """ + model_state = torch.load(dire + '/epoch_{}_model.pth'.format(epoch)) + model_state['model'] = collections.OrderedDict([(k.split('.', 1)[1], v) for k, v in model_state['model'].items()]) + model.load_state_dict(model_state['model']) + optim_state = torch.load(dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, rank)) + optimizer.load_state_dict(optim_state['optimizer']) + lr_scheduler_dict = optim_state['lr_scheduler'] + after_scheduler_dict = lr_scheduler_dict['after_scheduler'] + lr_scheduler_dict['after_scheduler'] = _CosineAnnealingLR( + optimizer, + after_scheduler_dict['T_max'], + after_scheduler_dict['eta_min'], + after_scheduler_dict['last_epoch'] + ) + lr_scheduler.load_state_dict(lr_scheduler_dict) diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py index b7edac8f9..f6194a55a 100644 --- a/colossalai/utils/model/colo_init_context.py +++ b/colossalai/utils/model/colo_init_context.py @@ -38,15 +38,18 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di # build param to spec mapping mapping1 = dict() mapping2 = dict() + mapping3 = dict() # gather all params has_dist_parameter = False with torch.no_grad(): for param in self.parameters(): - if isinstance(param, ColoParameter) and param.has_compute_spec(): + if isinstance(param, ColoParameter): has_dist_parameter = True mapping1[id(param)] = copy(param.dist_spec) mapping2[id(param)] = copy(param.compute_spec) + mapping3[id(param)] = param.get_process_group() param.set_dist_spec(distspec.replicate()) + param.process_group = None # TODO: fix when keep_vars = True # when keep_vars = False, the state_dict_func will call detach to create @@ -64,6 +67,7 @@ def colo_state_dict(self, destination=None, prefix='', keep_vars=False, state_di if param_id in mapping1: dist_spec = mapping1[id(param)] compute_spec = mapping2[id(param)] + param.process_group = mapping3[id(param)] param.set_tensor_spec(dist_spec, compute_spec) return ret diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py new file mode 100644 index 000000000..6e7d4441d --- /dev/null +++ b/tests/test_utils/test_colo_checkpoint.py @@ -0,0 +1,211 @@ +from abc import ABC, abstractmethod +import os, sys, shutil +import torch +import torch.nn as nn +import pytest +import copy +import operator +import colossalai +from colossalai.context.parallel_mode import ParallelMode +import torch.multiprocessing as mp +import torch.distributed as dist +from colossalai.testing import rerun_if_address_is_in_use +from colossalai.utils.cuda import get_current_device +from colossalai.utils import free_port +from colossalai.utils.model.colo_init_context import ColoInitContext +from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup, ColoTensor +from colossalai.core import global_context as gpc +from functools import partial +from colossalai.nn.parallel.data_parallel import ColoDDP +from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint +from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR + + +class DummyDataGenerator(ABC): + + def __init__(self, length=10): + self.length = length + + @abstractmethod + def generate(self): + pass + + def __iter__(self): + self.step = 0 + return self + + def __next__(self): + if self.step < self.length: + self.step += 1 + return self.generate() + else: + raise StopIteration + + def __len__(self): + return self.length + + +class DummyDataLoader(DummyDataGenerator): + batch_size = 128 + category = 16 + feature_size = 256 + + def generate(self): + image_dict = {} + image_dict['pixel_values'] = torch.rand( + DummyDataLoader.batch_size, DummyDataLoader.feature_size, device=get_current_device()) * 2 - 1 + image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,), + dtype=torch.int64, + device=get_current_device()) + return image_dict + + +class MLP(nn.Module): + + def __init__(self, in_features, out_features, hidden_features=None): + super().__init__() + if hidden_features is None: + hidden_features = out_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.fc2 = nn.Linear(hidden_features, out_features) + self.activation = nn.ReLU() + + def forward(self, x): + x = self.fc1(x) + x = self.activation(x) + x = self.fc2(x) + return x + + +def init_1d_row_for_linear_weight_spec(model, pg: ProcessGroup): + spec = (distspec.shard([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D)) + with DistSpecManager.no_grad(): + for n, p in model.named_parameters(): + if 'weight' in n: + p.set_process_group(pg) + p.set_tensor_spec(*spec) + + +def check_param_equal(model, torch_model): + for p, torch_p in zip(model.parameters(), torch_model.parameters()): + assert torch.allclose(torch_p, p, rtol=1e-3, atol=1e-1) + + +def remove(path): + """ param could either be relative or absolute. """ + if os.path.isfile(path) or os.path.islink(path): + os.remove(path) + elif os.path.isdir(path): + shutil.rmtree(path) + else: + raise ValueError("file {} is not a file or dir.".format(path)) + + +def run_checkpoint(init_spec_func, use_ddp, test_epoch, pg): + train_dataloader = DummyDataLoader(length=16) + with ColoInitContext(device=get_current_device()): + model = MLP(256, 16, 64) + model_reload = MLP(256, 16, 64) + model_ref = MLP(256, 16, 64) + model = model.cuda() + model_reload = model_reload.cuda() + model_ref = model_ref.cuda() + if use_ddp: + model = ColoDDP(model, pg) + model_reload = ColoDDP(model_reload, pg) + model_ref = ColoDDP(model_ref, pg) + + criterion = torch.nn.CrossEntropyLoss() + optimizer = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) + optimizer_reload = torch.optim.Adam(model_reload.parameters(), + lr=0.001, + betas=(0.9, 0.999), + eps=1e-08, + weight_decay=0) + optimizer_ref = torch.optim.Adam(model_ref.parameters(), lr=0.001, betas=(0.9, 0.999), eps=1e-08, weight_decay=0) + + lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer, total_steps=20, warmup_steps=5) + lr_scheduler_reload = CosineAnnealingWarmupLR(optimizer=optimizer_reload, total_steps=20, warmup_steps=5) + lr_scheduler_ref = CosineAnnealingWarmupLR(optimizer=optimizer_ref, total_steps=20, warmup_steps=5) + + init_spec_func(model, pg) + init_spec_func(model_ref, pg) + + for epoch in range(0, 20): + if epoch <= test_epoch: + for i, image_dict in enumerate(train_dataloader): + if use_ddp: + model.zero_grad() + else: + optimizer.zero_grad() + logits = model(image_dict['pixel_values']) + loss = criterion(logits, image_dict['label']) + if use_ddp: + model.backward(loss) + else: + loss.backward() + optimizer.step() + + if epoch == test_epoch: + for ref_p, p in zip(model_ref.parameters(), model.parameters()): + ref_p.data.copy_(p) + optimizer_ref = copy.deepcopy(optimizer) + lr_scheduler_ref = copy.deepcopy(lr_scheduler) + + check_param_equal(model, model_ref) + save_checkpoint('./checkpoint', epoch, model, optimizer, lr_scheduler) + dist.barrier() + else: + if epoch == test_epoch + 1: + load_checkpoint('./checkpoint', test_epoch, dist.get_rank(), model_reload, optimizer_reload, + lr_scheduler_reload) + init_spec_func(model_reload, pg) + for i, image_dict in enumerate(train_dataloader): + if use_ddp: + model_ref.zero_grad() + model_reload.zero_grad() + else: + optimizer_ref.zero_grad() + optimizer_reload.zero_grad() + logits_ref = model_ref(image_dict['pixel_values']) + logits_reload = model_reload(image_dict['pixel_values']) + loss_ref = criterion(logits_ref, image_dict['label']) + loss_reload = criterion(logits_reload, image_dict['label']) + if use_ddp: + model_ref.backward(loss_ref) + model_reload.backward(loss_reload) + else: + loss_ref.backward() + loss_reload.backward() + optimizer_ref.step() + optimizer_reload.step() + lr_scheduler.step() + + check_param_equal(model_ref, model_reload) + + +def run_dist(rank, world_size, port, use_ddp, test_epoch): + if use_ddp and world_size == 1: + return + tp_world_size = world_size // 2 if use_ddp else world_size + config = dict(parallel=dict(tensor=dict(mode="1d", size=tp_world_size),)) + colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + pg = ProcessGroup(tp_degree=world_size) + run_checkpoint(init_1d_row_for_linear_weight_spec, use_ddp, test_epoch, pg) + + +@pytest.mark.dist +@pytest.mark.parametrize('world_size', [4]) +@pytest.mark.parametrize('use_ddp', [True]) +@pytest.mark.parametrize('test_epoch', [1, 2, 3]) +@rerun_if_address_is_in_use() +def test_checkpoint(world_size, use_ddp, test_epoch): + if not os.path.isdir('./checkpoint'): + os.mkdir('./checkpoint') + run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp, test_epoch=test_epoch) + mp.spawn(run_func, nprocs=world_size) + remove('./checkpoint') + + +if __name__ == '__main__': + test_checkpoint(4, True, 1)