#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import argparse
import pprint
import random
from pathlib import Path
from typing import Callable, Iterable, Optional, Tuple, Union

import numpy as np
import torch
from torch.utils.data import DataLoader

from colossalai.engine import AMP_TYPE, Engine, NoPipelineSchedule, PipelineSchedule
from colossalai.logging import get_global_dist_logger, init_global_dist_logger
from colossalai.nn import DataParallelSampler
from colossalai.nn.model.base_model import BaseModel
from .builder import (ModelInitializer, build_dataset, build_loss,
                      build_model, build_optimizer,
                      build_optimizer_wrapper, build_schedule)
from .context import Config, ParallelMode
from .core import global_context as gpc
from .utils import get_current_device, sync_model_param_in_dp


def parse_args():
    '''Reads the user command line and uses an argument parser to parse the input arguments.
    Input arguments include the configuration, host, port, world size, local rank and backend for torch.distributed.

    :return: the parsed arguments
    :rtype: Namespace
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, help='path to the config file')
    parser.add_argument('--host',
                        type=str,
                        default=None,
                        help='the master address for distributed training')
    parser.add_argument('--port',
                        type=str,
                        default=None,
                        help='the master port for distributed training')
    parser.add_argument('--world_size', type=int,
                        help='world size for distributed training')
    parser.add_argument('--local_rank',
                        type=int,
                        help='rank for the default process group')
    parser.add_argument('--backend',
                        type=str,
                        default='nccl',
                        help='backend for torch.distributed')
    return parser.parse_args()
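

# Illustrative usage sketch (hypothetical script and config names): these arguments are
# normally supplied by a process launcher rather than typed by hand, e.g.
#
#   python -m torch.distributed.launch --nproc_per_node 4 train.py \
#       --config ./config.py --host 127.0.0.1 --port 29500 --world_size 4 --backend nccl
#
# with `--local_rank` injected into each process by the launcher.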


def init_dist(config: Union[str, dict] = None,
              local_rank: int = None,
              world_size: int = None,
              host: str = None,
              port: str = None,
              backend: str = None):
    '''Parses the configuration arguments, falling back to :func:`parse_args` when any of
    the input arguments is not given, then initializes and sets up the distributed
    environment by calling the global context's functions.

    :param config: config file or config file path are both acceptable
    :type config: Union[str, dict], optional
    :param local_rank: rank for the default process group, defaults to None
    :type local_rank: int, optional
    :param world_size: world size of the distributed job (total number of GPUs), defaults to None
    :type world_size: int, optional
    :param host: the master address for distributed training, defaults to None
    :type host: str, optional
    :param port: the master port for distributed training, defaults to None
    :type port: str, optional
    :param backend: backend for torch.distributed, defaults to None
    :type backend: str, optional
    :raises Exception: raised when the config type is wrong
    '''
    args = [config, local_rank, world_size, host, port, backend]
    arg_given = [arg is not None for arg in args]

    if not all(arg_given):
        args = parse_args()

    if config is None:
        config = args.config
    if local_rank is None:
        local_rank = args.local_rank
    if world_size is None:
        world_size = args.world_size
    if host is None:
        host = args.host
    if port is None:
        port = args.port
    if backend is None:
        backend = args.backend
    args = Config(
        dict(config=config,
             host=host,
             port=port,
             world_size=world_size,
             local_rank=local_rank,
             backend=backend))

    # set distributed settings
    dist_args = Config(
        dict(local_rank=args.local_rank,
             world_size=args.world_size,
             backend=args.backend))
    gpc.set_dist_args(dist_args)

    # set config
    if isinstance(args.config, dict):
        cfg = args.config
    elif isinstance(args.config, (str, Path)):
        cfg = Config.from_file(args.config)
    else:
        raise Exception('Config type error: {}'.format(type(args.config)))
    gpc.load_config(cfg)

    # init dist groups
    gpc.init_global_dist(args.host, args.port)
    gpc.init_parallel_groups()

    # init dist logger
    init_global_dist_logger()

    # set cuda device
    if torch.cuda.is_available():
        gpc.set_device()
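

# Illustrative usage sketch: `init_dist` can also be called programmatically instead of
# relying on command-line arguments; './config.py' below is a placeholder path.
#
#   init_dist(config='./config.py', local_rank=0, world_size=1,
#             host='127.0.0.1', port='29500', backend='nccl')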


def get_dataloader(dataset, seed=1024, add_sampler_if_possible=False, **kwargs):
    '''Sets up a deterministic dataloader (also configures seed workers, samplers and whether to shuffle or not).

    .. note:: When pipeline parallel is enabled, shuffle cannot be True, as it will
        result in a mismatch between the input data on the first stage and the
        labels on the last stage.

    :param dataset: a :class:`torch.utils.data.Dataset` object
    :param seed: random worker seed, defaults to 1024
    :type seed: int, optional
    :param add_sampler_if_possible: attach a :class:`DataParallelSampler` when the data parallel world size is larger than 1, defaults to False
    :type add_sampler_if_possible: bool, optional
    :return: a dataloader for the given dataset
    :rtype: torch.utils.data.DataLoader
    '''
    _kwargs = kwargs.copy()
    if 'shuffle' in _kwargs:
        shuffle = _kwargs.pop('shuffle')
    else:
        shuffle = False

    if add_sampler_if_possible and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
        sampler = DataParallelSampler(dataset, shuffle=shuffle)
    else:
        sampler = None

    # Deterministic dataloader
    def seed_worker(worker_id):
        worker_seed = seed
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)
        random.seed(worker_seed)

    if sampler is None:
        return DataLoader(dataset,
                          worker_init_fn=seed_worker,
                          shuffle=shuffle,
                          **_kwargs)
    else:
        return DataLoader(dataset,
                          sampler=sampler,
                          worker_init_fn=seed_worker,
                          **_kwargs)
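

# Illustrative usage sketch: `MyDataset` below is a hypothetical placeholder for any
# torch.utils.data.Dataset; the extra keyword arguments are forwarded to DataLoader.
#
#   train_loader = get_dataloader(MyDataset(), seed=1024, add_sampler_if_possible=True,
#                                 batch_size=32, shuffle=True, num_workers=2)
#
# With `add_sampler_if_possible=True`, a DataParallelSampler replaces plain shuffling
# whenever the data parallel group has more than one rank.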


def initialize(config: Union[str, dict] = None,
               local_rank: int = None,
               world_size: int = None,
               host: str = None,
               port: str = None,
               backend: str = None,
               train_dataloader: Optional[Union[Iterable, Callable]] = None,
               test_dataloader: Optional[Union[Iterable, Callable]] = None,
               ) -> Tuple[Engine, DataLoader, DataLoader]:
    '''Core function that initializes the distributed environment, logger, cudnn, data, model,
    loss function, optimizer and lr_scheduler (their configs are in gpc.config).

    :param config: config file or config file path are both acceptable
    :type config: Union[str, dict], optional
    :param local_rank: rank for the default process group, defaults to None
    :type local_rank: int, optional
    :param world_size: world size of the distributed job (total number of GPUs), defaults to None
    :type world_size: int, optional
    :param host: the master address for distributed training, defaults to None
    :type host: str, optional
    :param port: the master port for distributed training, defaults to None
    :type port: str, optional
    :param backend: backend for torch.distributed, defaults to None
    :type backend: str, optional
    :param train_dataloader: if None, the config is used to build a dataloader; otherwise, it should be a dataloader object or a no-argument function that builds a dataloader, defaults to None
    :type train_dataloader: Optional[Union[Iterable, Callable]], optional
    :param test_dataloader: if None, the config is used to build a dataloader; otherwise, it should be a dataloader object or a no-argument function that builds a dataloader, defaults to None
    :type test_dataloader: Optional[Union[Iterable, Callable]], optional
    :return: (engine, train_dataloader, test_dataloader)
    :rtype: Tuple[Engine, DataLoader, DataLoader]
    '''
    # initialize distributed environment
    init_dist(config=config,
              local_rank=local_rank,
              world_size=world_size,
              host=host,
              port=port,
              backend=backend)

    # init logger
    logger = get_global_dist_logger()
    logger.info(f'Distributed environment is initialized, '
                f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
                f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])

    # print config
    logger.info(f"\n========== Your Config ========\n"
                f"{pprint.pformat(gpc.config)}\n"
                f"================================", ranks=[0])

    # cudnn
    cudnn_benchmark = gpc.config.get('cudnn_benchmark', True)
    cudnn_deterministic = gpc.config.get('cudnn_deterministic', False)
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.backends.cudnn.deterministic = cudnn_deterministic
    logger.info(
        f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])

    # set seed, cuda seed is only set when cuda is available
    gpc.set_seed()
    # return_items = list()

    # check fp16 and zero
    should_convert_model_to_half = False
    should_wrap_fp16_optimizer = False
    should_wrap_zero_optimizer_level_2_3 = False

    if hasattr(gpc.config, 'fp16'):
        fp16_mode = gpc.config.fp16.mode
        if fp16_mode == AMP_TYPE.PARALLEL:
            should_convert_model_to_half = True
            should_wrap_fp16_optimizer = True

    if hasattr(gpc.config, 'zero'):
        should_wrap_zero_optimizer_level_2_3 = True
        zero_type = gpc.config.zero.type
        if zero_type in ['ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3']:
            should_convert_model_to_half = True
            assert not should_wrap_fp16_optimizer, \
                'AMP_TYPE.PARALLEL is mutually exclusive with zero level 2 and 3'
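
    # Illustrative config sketch (hypothetical extra fields): the checks above expect
    # optional config entries shaped roughly like
    #     fp16 = dict(mode=AMP_TYPE.PARALLEL, ...)
    #     zero = dict(type='ZeroRedundancyOptimizer_Level_2', ...)
    # Only `mode` and `type` are inspected here; the remaining fields are forwarded to
    # the optimizer wrappers / AMP configuration built further below.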

    # build model
    logger.info('Building model ...', ranks=[0])
    assert hasattr(
        gpc.config, 'model'), "Build error: configuration 'model' is missing"
    if gpc.pipeline_parallel_size > 1:
        model = ModelInitializer(gpc.config.model, 1, verbose=True)
        model = model.model_initialize()
    else:
        model = build_model(gpc.config.model)
        if isinstance(model, BaseModel):
            model.build_from_cfg()
        model = model.to(get_current_device())
    sync_model_param_in_dp(model)
    logger.info('Model is created', ranks=[0])

    if should_convert_model_to_half:
        model = model.half()
        logger.info("Model is cast to fp16", ranks=[0])

    # training data
    if callable(train_dataloader):
        logger.info(
            f'Build train data loader from {train_dataloader}', ranks=[0])
        train_dataloader = train_dataloader()
    if train_dataloader is None and hasattr(gpc.config, 'train_data'):
        logger.info('Preparing data ...', ranks=[0])
        # assert hasattr(gpc.config, 'train_data'), "Build error: configuration 'train_data' is missing."
        train_dataset = build_dataset(gpc.config.train_data.dataset)
        logger.info('Train dataset is ready.', ranks=[0])

        train_dataloader = get_dataloader(train_dataset,
                                          gpc.config.get('seed', 1024),
                                          True,
                                          **gpc.config.train_data.dataloader,
                                          )
        logger.info(
            f'Loaded {len(train_dataset)} samples in {len(train_dataloader)} batches for training', ranks=[0])

    if callable(test_dataloader):
        logger.info(
            f'Build test data loader from {test_dataloader}', ranks=[0])
        test_dataloader = test_dataloader()
    # testing data, allowed to be None
    if test_dataloader is None and hasattr(gpc.config, 'test_data'):
        test_dataset = build_dataset(gpc.config.test_data.dataset)
        test_dataloader = get_dataloader(
            test_dataset, add_sampler_if_possible=True, **gpc.config.test_data.dataloader)
        logger.info(
            f'Loaded {len(test_dataset)} samples in {len(test_dataloader)} batches for testing', ranks=[0])

    # build loss function
    assert hasattr(gpc.config, 'loss'), \
        "Build error: configuration 'loss' is missing."
    criterion = build_loss(gpc.config.loss)
    logger.info('Loss function is created', ranks=[0])

    # build optimizer
    assert hasattr(gpc.config, 'optimizer'), \
        "Build error: configuration 'optimizer' is missing."
    optim_type = gpc.config.optimizer.type
    is_pytorch_native_zero_level_1 = optim_type == 'ZeroRedundancyOptimizer'
    if is_pytorch_native_zero_level_1:
        original_cfg_copy = gpc.config.optimizer.copy()
        original_cfg_copy.pop('type')
        cfg = dict(type=optim_type, process_group=gpc.get_group(
            ParallelMode.DATA), **original_cfg_copy)
        optimizer = build_optimizer(cfg, model)
    else:
        optimizer = build_optimizer(gpc.config.optimizer, model)

    if should_wrap_zero_optimizer_level_2_3:
        optimizer = build_optimizer_wrapper(gpc.config.zero, optimizer, model)

    if should_wrap_fp16_optimizer:
        # replace the field `mode` with `type`
        fp16_cfg = gpc.config.fp16.copy()
        amp_type = fp16_cfg.pop('mode')
        assert amp_type == AMP_TYPE.PARALLEL, 'FP16Optimizer should only be used for AMP_TYPE.PARALLEL'
        fp16_cfg['type'] = 'FP16Optimizer'
        optimizer = build_optimizer_wrapper(fp16_cfg, optimizer)
    logger.info('Optimizer is created', ranks=[0])

    # build schedule and engine
    if hasattr(gpc.config, 'fp16'):
        amp_type = gpc.config.fp16.mode
        amp_cfg = gpc.config.fp16.copy()
        amp_cfg.pop('mode')
    else:
        amp_type = None
        amp_cfg = None

    engine_cfg = gpc.config.get('engine', dict())
    schedule_cfg = engine_cfg.pop('schedule', None)

    schedule_type = None
    if schedule_cfg is not None:
        schedule_type = schedule_cfg.get('type', None)

    if schedule_type is not None:
        # run customized schedule
        schedule_cfg['amp_type'] = amp_type
        schedule_cfg['amp_config'] = amp_cfg
        schedule = build_schedule(schedule_cfg)
    elif gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
        assert schedule_cfg is not None, \
            "Config 'engine.schedule' not found in your configuration file for pipeline parallel training"
        schedule = PipelineSchedule(
            amp_type=amp_type, amp_config=amp_cfg, **schedule_cfg.copy())
    else:
        schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg)

    engine = Engine(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        step_schedule=schedule,
        **gpc.config.get('engine', dict())
    )

    return engine, train_dataloader, test_dataloader
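

# Illustrative call-site sketch (hypothetical config path): a training script would
# typically do something like
#
#   engine, train_dataloader, test_dataloader = initialize(config='./config.py')
#
# and then drive its training loop through the returned Engine, which wraps the model,
# optimizer, criterion and step schedule assembled above.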