ColossalAI/tests/test_trainer/test_pipeline/test_schedule.py

52 lines
1.2 KiB
Python

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os.path as osp
import pytest
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.initialize import initialize
from colossalai.logging import get_dist_logger
# Problem-size constants. NOTE(review): none of these four is referenced by the
# code in this file as far as visible here — possibly leftovers; confirm before removing.
NUM_BATCH, BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE = 128, 32, 128, 512

# Directory holding this test file (symlinks resolved), and the pipeline
# config file, located relative to it, that test_schedule hands to initialize().
DIR_PATH = osp.dirname(osp.realpath(__file__))
CONFIG_PATH = osp.join(DIR_PATH, '../configs/pipeline_vanilla_resnet.py')
@pytest.mark.skip("This test should be invoked using the test.sh provided")
@pytest.mark.dist
def test_schedule():
    """Run a single forward/backward pass and optimizer step through the engine's schedule.

    The engine and training dataloader are built by ``initialize(CONFIG_PATH)``.
    The engine's schedule is then driven directly for one step over the
    training dataloader. The loss is logged only on the last pipeline rank.
    ``gpc.destroy()`` is always called once the step has been attempted, even
    if the step raises, so a failure does not leave the global context
    initialized.
    """
    # Kept outside the try: if initialize() itself fails, the global context
    # may not be set up, and a destroy() call could mask the original error.
    engine, train_dataloader, _ = initialize(CONFIG_PATH)
    logger = get_dist_logger()

    try:
        model = engine.model
        optimizer = engine.optimizer
        criterion = engine.criterion
        schedule = engine._schedule

        # Only the loss is inspected; the returned output and label are discarded.
        _, _, loss = schedule.forward_backward_step(
            data_iter=iter(train_dataloader),
            model=model,
            optimizer=optimizer,
            criterion=criterion,
            forward_only=False,
        )
        schedule.optimizer_step(model, optimizer)

        # Log only on the last pipeline rank. Presumably that is the stage
        # holding the real loss value — TODO confirm against the schedule.
        if gpc.is_last_rank(ParallelMode.PIPELINE):
            logger.info(f'losses: {loss}')
    finally:
        # Tear down the global context whether or not the step succeeded.
        gpc.destroy()

    logger.info('training finished')
# Script entry point. The skip marker on test_schedule says this test is meant
# to be launched through the provided test.sh rather than by pytest directly.
if __name__ == '__main__':  test_schedule()