[checkpoint] make unitest faster (#1217)

pull/1204/head
Jiarui Fang 2 years ago committed by GitHub
parent f38006ea83
commit 52736205d9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,7 +5,8 @@ import collections
from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR
from colossalai.utils.model.colo_init_context import colo_state_dict from colossalai.utils.model.colo_init_context import colo_state_dict
def save_checkpoint(dire,
def save_checkpoint(dire: str,
epoch: int, epoch: int,
model: torch.nn.Module, model: torch.nn.Module,
optimizer: torch.optim.Optimizer = None, optimizer: torch.optim.Optimizer = None,
@ -15,30 +16,21 @@ def save_checkpoint(dire,
"""save_checkpoint """save_checkpoint
save a model, whose parameters are `ColoTensor`s. save a model, whose parameters are `ColoTensor`s.
Args: Args:
dire (_type_): _description_ dire (str): directory to save the checkpoint files.
epoch (int): _description_ epoch (int): the number of epoch
model (torch.nn.Module): _description_ model (torch.nn.Module): a torch module initialized by ColoInitContext
optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None. optimizer (torch.optim.Optimizer, optional): optimizers. Defaults to None.
lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None. lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): lr schedule. Defaults to None.
""" """
model_state = { model_state = {'epoch': epoch, 'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict)}
'epoch': epoch,
'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict)
}
if dist.get_rank() == 0: if dist.get_rank() == 0:
torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch)) torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
lr_scheduler_dict = lr_scheduler.state_dict() lr_scheduler_dict = lr_scheduler.state_dict()
lr_scheduler_dict['after_scheduler'] = lr_scheduler_dict['after_scheduler'].state_dict() lr_scheduler_dict['after_scheduler'] = lr_scheduler_dict['after_scheduler'].state_dict()
optim_state = { optim_state = {'epoch': epoch, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler_dict}
'epoch': epoch,
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler_dict
}
torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank())) torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))
def load_checkpoint(dire, def load_checkpoint(dire,
epoch: int, epoch: int,
rank: int, rank: int,
@ -64,10 +56,7 @@ def load_checkpoint(dire,
optimizer.load_state_dict(optim_state['optimizer']) optimizer.load_state_dict(optim_state['optimizer'])
lr_scheduler_dict = optim_state['lr_scheduler'] lr_scheduler_dict = optim_state['lr_scheduler']
after_scheduler_dict = lr_scheduler_dict['after_scheduler'] after_scheduler_dict = lr_scheduler_dict['after_scheduler']
lr_scheduler_dict['after_scheduler'] = _CosineAnnealingLR( lr_scheduler_dict['after_scheduler'] = _CosineAnnealingLR(optimizer, after_scheduler_dict['T_max'],
optimizer,
after_scheduler_dict['T_max'],
after_scheduler_dict['eta_min'], after_scheduler_dict['eta_min'],
after_scheduler_dict['last_epoch'] after_scheduler_dict['last_epoch'])
)
lr_scheduler.load_state_dict(lr_scheduler_dict) lr_scheduler.load_state_dict(lr_scheduler_dict)

@ -1,21 +1,20 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import os, sys, shutil import os, shutil
import torch import torch
import torch.nn as nn import torch.nn as nn
import pytest import pytest
import copy import copy
import operator from functools import partial
import colossalai
from colossalai.context.parallel_mode import ParallelMode
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.distributed as dist import torch.distributed as dist
import colossalai
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils.cuda import get_current_device from colossalai.utils.cuda import get_current_device
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.utils.model.colo_init_context import ColoInitContext from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup, ColoTensor from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
from colossalai.core import global_context as gpc
from functools import partial
from colossalai.nn.parallel.data_parallel import ColoDDP from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@ -46,15 +45,17 @@ class DummyDataGenerator(ABC):
class DummyDataLoader(DummyDataGenerator): class DummyDataLoader(DummyDataGenerator):
batch_size = 128
category = 16 def __init__(self, batch_size, category, feature_size, length=10):
feature_size = 256 super().__init__(length)
self.batch_size = batch_size
self.category = category
self.feature_size = feature_size
def generate(self): def generate(self):
image_dict = {} image_dict = {}
image_dict['pixel_values'] = torch.rand( image_dict['pixel_values'] = torch.rand(self.batch_size, self.feature_size, device=get_current_device()) * 2 - 1
DummyDataLoader.batch_size, DummyDataLoader.feature_size, device=get_current_device()) * 2 - 1 image_dict['label'] = torch.randint(self.category, (self.batch_size,),
image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,),
dtype=torch.int64, dtype=torch.int64,
device=get_current_device()) device=get_current_device())
return image_dict return image_dict
@ -102,11 +103,15 @@ def remove(path):
def run_checkpoint(init_spec_func, use_ddp, test_epoch, pg): def run_checkpoint(init_spec_func, use_ddp, test_epoch, pg):
train_dataloader = DummyDataLoader(length=16) batch = 3
feature = 32
category = 16
train_dataloader = DummyDataLoader(batch, category, feature, length=16)
with ColoInitContext(device=get_current_device()): with ColoInitContext(device=get_current_device()):
model = MLP(256, 16, 64) model = MLP(feature, category)
model_reload = MLP(256, 16, 64) model_reload = MLP(feature, category)
model_ref = MLP(256, 16, 64) model_ref = MLP(feature, category)
model = model.cuda() model = model.cuda()
model_reload = model_reload.cuda() model_reload = model_reload.cuda()
model_ref = model_ref.cuda() model_ref = model_ref.cuda()

Loading…
Cancel
Save