diff --git a/colossalai/utils/checkpoint/module_checkpoint.py b/colossalai/utils/checkpoint/module_checkpoint.py
index c4de1c5ea..0cdb17d6c 100644
--- a/colossalai/utils/checkpoint/module_checkpoint.py
+++ b/colossalai/utils/checkpoint/module_checkpoint.py
@@ -5,7 +5,8 @@ import collections
 from torch.optim.lr_scheduler import CosineAnnealingLR as _CosineAnnealingLR
 from colossalai.utils.model.colo_init_context import colo_state_dict
 
-def save_checkpoint(dire,
+
+def save_checkpoint(dire: str,
                     epoch: int,
                     model: torch.nn.Module,
                     optimizer: torch.optim.Optimizer = None,
@@ -15,30 +16,21 @@ def save_checkpoint(dire,
     """save_checkpoint
 
     save a model, whose parameters are `ColoTensor`s.
 
     Args:
-        dire (_type_): _description_
-        epoch (int): _description_
-        model (torch.nn.Module): _description_
-        optimizer (torch.optim.Optimizer, optional): _description_. Defaults to None.
-        lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): _description_. Defaults to None.
+        dire (str): directory to save the checkpoint files.
+        epoch (int): the epoch number.
+        model (torch.nn.Module): a torch module initialized by ColoInitContext.
+        optimizer (torch.optim.Optimizer, optional): the optimizer. Defaults to None.
+        lr_scheduler (torch.optim.lr_scheduler._LRScheduler, optional): the lr scheduler. Defaults to None.
     """
-    model_state = {
-        'epoch': epoch,
-        'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict)
-    }
+    model_state = {'epoch': epoch, 'model': colo_state_dict(model, state_dict_func=nn.Module.state_dict)}
     if dist.get_rank() == 0:
         torch.save(model_state, dire + '/epoch_{}_model.pth'.format(epoch))
     lr_scheduler_dict = lr_scheduler.state_dict()
     lr_scheduler_dict['after_scheduler'] = lr_scheduler_dict['after_scheduler'].state_dict()
-    optim_state = {
-        'epoch': epoch,
-        'optimizer': optimizer.state_dict(),
-        'lr_scheduler': lr_scheduler_dict
-    }
+    optim_state = {'epoch': epoch, 'optimizer': optimizer.state_dict(), 'lr_scheduler': lr_scheduler_dict}
     torch.save(optim_state, dire + '/epoch_{}_optim_rank_{}.pth'.format(epoch, dist.get_rank()))
-
-
 def load_checkpoint(dire,
                     epoch: int,
                     rank: int,
@@ -64,10 +56,7 @@ def load_checkpoint(dire,
     optimizer.load_state_dict(optim_state['optimizer'])
     lr_scheduler_dict = optim_state['lr_scheduler']
     after_scheduler_dict = lr_scheduler_dict['after_scheduler']
-    lr_scheduler_dict['after_scheduler'] = _CosineAnnealingLR(
-        optimizer,
-        after_scheduler_dict['T_max'],
-        after_scheduler_dict['eta_min'],
-        after_scheduler_dict['last_epoch']
-    )
+    lr_scheduler_dict['after_scheduler'] = _CosineAnnealingLR(optimizer, after_scheduler_dict['T_max'],
+                                                              after_scheduler_dict['eta_min'],
+                                                              after_scheduler_dict['last_epoch'])
     lr_scheduler.load_state_dict(lr_scheduler_dict)
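For context, a minimal usage sketch of the revised API (not part of the diff). It assumes load_checkpoint takes (dire, epoch, rank, model, optimizer, lr_scheduler), extrapolating from the truncated signature above, and that a hypothetical './ckpt' directory already exists on every rank:

import torch.distributed as dist
from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint

def checkpoint_roundtrip(model, optimizer, lr_scheduler, epoch, ckpt_dir='./ckpt'):
    # Rank 0 writes epoch_{epoch}_model.pth; every rank writes its own
    # epoch_{epoch}_optim_rank_{rank}.pth shard, mirroring save_checkpoint above.
    save_checkpoint(ckpt_dir, epoch, model, optimizer, lr_scheduler)
    dist.barrier()    # make sure all shards are on disk before any rank reloads
    # Each rank reloads its own optimizer/lr_scheduler shard by rank id
    # (assumed call shape; the hunk above only shows dire, epoch and rank).
    load_checkpoint(ckpt_dir, epoch, dist.get_rank(), model, optimizer, lr_scheduler)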
diff --git a/tests/test_utils/test_colo_checkpoint.py b/tests/test_utils/test_colo_checkpoint.py
index 6e7d4441d..48742fc18 100644
--- a/tests/test_utils/test_colo_checkpoint.py
+++ b/tests/test_utils/test_colo_checkpoint.py
@@ -1,21 +1,20 @@
 from abc import ABC, abstractmethod
-import os, sys, shutil
+import os, shutil
 import torch
 import torch.nn as nn
 import pytest
 import copy
-import operator
-import colossalai
-from colossalai.context.parallel_mode import ParallelMode
+from functools import partial
+
 import torch.multiprocessing as mp
 import torch.distributed as dist
+
+import colossalai
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup, ColoTensor
-from colossalai.core import global_context as gpc
-from functools import partial
+from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, distspec, ProcessGroup
 from colossalai.nn.parallel.data_parallel import ColoDDP
 from colossalai.utils.checkpoint import save_checkpoint, load_checkpoint
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
@@ -46,15 +45,17 @@ class DummyDataGenerator(ABC):
 
 
 class DummyDataLoader(DummyDataGenerator):
-    batch_size = 128
-    category = 16
-    feature_size = 256
+
+    def __init__(self, batch_size, category, feature_size, length=10):
+        super().__init__(length)
+        self.batch_size = batch_size
+        self.category = category
+        self.feature_size = feature_size
 
     def generate(self):
         image_dict = {}
-        image_dict['pixel_values'] = torch.rand(
-            DummyDataLoader.batch_size, DummyDataLoader.feature_size, device=get_current_device()) * 2 - 1
-        image_dict['label'] = torch.randint(DummyDataLoader.category, (DummyDataLoader.batch_size,),
+        image_dict['pixel_values'] = torch.rand(self.batch_size, self.feature_size, device=get_current_device()) * 2 - 1
+        image_dict['label'] = torch.randint(self.category, (self.batch_size,),
                                             dtype=torch.int64, device=get_current_device())
         return image_dict
 
@@ -102,11 +103,15 @@ def remove(path):
 
 
 def run_checkpoint(init_spec_func, use_ddp, test_epoch, pg):
-    train_dataloader = DummyDataLoader(length=16)
+    batch = 3
+    feature = 32
+    category = 16
+    train_dataloader = DummyDataLoader(batch, category, feature, length=16)
     with ColoInitContext(device=get_current_device()):
-        model = MLP(256, 16, 64)
-        model_reload = MLP(256, 16, 64)
-        model_ref = MLP(256, 16, 64)
+        model = MLP(feature, category)
+        model_reload = MLP(feature, category)
+        model_ref = MLP(feature, category)
+
     model = model.cuda()
     model_reload = model_reload.cuda()
     model_ref = model_ref.cuda()
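To illustrate the test refactor (not part of the diff): the loader's sizes are now per-instance constructor arguments instead of shared class attributes, so each test can pick shapes small enough to run quickly. A sketch, assuming DummyDataGenerator iterates by yielding generate() output length times, as the test loop relies on:

loader = DummyDataLoader(batch_size=3, category=16, feature_size=32, length=16)
for data in loader:
    # shapes now follow the constructor arguments, not class-level constants
    assert data['pixel_values'].shape == (3, 32)    # (batch_size, feature_size)
    assert data['label'].shape == (3,)              # labels drawn from [0, category)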