#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import argparse
import pprint
import random
from pathlib import Path
from typing import Callable, Iterable, Optional, Union
from typing import Tuple

import numpy as np
import torch
from torch.utils.data import DataLoader

from colossalai.engine import AMP_TYPE, NoPipelineSchedule, PipelineSchedule
from colossalai.engine import Engine
from colossalai.logging import get_global_dist_logger, init_global_dist_logger
from colossalai.nn import DataParallelSampler
from colossalai.nn.model.base_model import BaseModel
from .builder import (ModelInitializer, build_dataset, build_loss,
                      build_model, build_optimizer,
                      build_optimizer_wrapper, build_schedule)
from .context import Config, ParallelMode
from .core import global_context as gpc
from .utils import get_current_device, sync_model_param_in_dp

def parse_args():
    '''Reads the user command line and uses an argument parser to parse the input arguments.
    Input arguments include the configuration, host, port, world size, local rank and backend for torch.distributed.

    :return: the parsed arguments
    :rtype: Namespace
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, help='path to the config file')
    parser.add_argument('--host',
                        type=str,
                        default=None,
                        help='the master address for distributed training')
    parser.add_argument('--port',
                        type=str,
                        default=None,
                        help='the master port for distributed training')
    parser.add_argument('--world_size', type=int, help='world size for distributed training')
    parser.add_argument('--local_rank',
                        type=int,
                        help='rank for the default process group')
    parser.add_argument('--backend',
                        type=str,
                        default='nccl',
                        help='backend for torch.distributed')
    return parser.parse_args()

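
# Usage sketch (illustrative, not part of the original file): these flags are
# normally supplied by a distributed launcher. torch.distributed.launch passes
# --local_rank automatically and forwards the remaining arguments to the script,
# e.g. for a hypothetical train.py entry point:
#
#   python -m torch.distributed.launch --nproc_per_node=4 train.py \
#       --config ./config.py --host 127.0.0.1 --port 29500 --world_size 4 --backend nccl
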
def init_dist(config: Union[str, dict] = None,
              local_rank: int = None,
              world_size: int = None,
              host: str = None,
              port: str = None,
              backend: str = None):
    '''This function first parses the configuration arguments, using :func:`parse_args` in case one of the input arguments is not given.
    It then initializes and sets up the distributed environment by calling global_context's functions.

    :param config: a config dict or a path to a config file are both acceptable
    :type config: Union[str, dict], optional
    :param local_rank: rank for the default process group, defaults to None
    :type local_rank: int, optional
    :param world_size: world size of GPUs, defaults to None
    :type world_size: int, optional
    :param host: the master address for distributed training, defaults to None
    :type host: str, optional
    :param port: the master port for distributed training, defaults to None
    :type port: str, optional
    :param backend: backend for torch.distributed, defaults to None
    :type backend: str, optional
    :raises Exception: raised when the config is neither a dict nor a file path
    '''
    args = [config, local_rank, world_size, host, port, backend]
    arg_given = [arg is not None for arg in args]

    if not all(arg_given):
        args = parse_args()

    if config is None:
        config = args.config
    if local_rank is None:
        local_rank = args.local_rank
    if world_size is None:
        world_size = args.world_size
    if host is None:
        host = args.host
    if port is None:
        port = args.port
    if backend is None:
        backend = args.backend
    args = Config(
        dict(config=config,
             host=host,
             port=port,
             world_size=world_size,
             local_rank=local_rank,
             backend=backend))

    # set distributed settings
    dist_args = Config(
        dict(local_rank=args.local_rank,
             world_size=args.world_size,
             backend=args.backend))

    gpc.set_dist_args(dist_args)

    # set config
    if isinstance(args.config, dict):
        cfg = args.config
    elif isinstance(args.config, (str, Path)):
        cfg = Config.from_file(args.config)
    else:
        raise Exception('Config type error: {}'.format(type(args.config)))
    gpc.load_config(cfg)

    # init dist groups
    gpc.init_global_dist(args.host, args.port)
    gpc.init_parallel_groups()

    # init dist logger
    init_global_dist_logger()

    # set cuda device
    if torch.cuda.is_available():
        gpc.set_device()

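
# Usage sketch (illustrative only, not part of the original file): init_dist can
# be called directly from a training script instead of relying on command-line
# arguments. The config path, address and port below are placeholders.
def _example_init_dist():
    init_dist(config='./config.py',
              local_rank=0,
              world_size=1,
              host='127.0.0.1',
              port='29500',
              backend='nccl')
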
def get_dataloader(dataset, seed=1024, add_sampler_if_possible=False, **kwargs):
    '''Set up a deterministic dataloader (also configure seed workers, samplers and whether to shuffle or not)

    .. note:: When pipeline parallel is enabled, shuffle cannot be True,
        as it will result in a mismatch between the input data on the 1st
        stage and the label on the last stage.

    :param dataset: a :class:`torch.utils.data.Dataset` object
    :param seed: random worker seed, defaults to 1024
    :type seed: int, optional
    :param add_sampler_if_possible: add a :class:`DataParallelSampler` to the dataloader when data-parallel training is initialized, defaults to False
    :type add_sampler_if_possible: bool, optional
    :return: a dataloader built from the given dataset
    :rtype: torch.utils.data.DataLoader
    '''
    _kwargs = kwargs.copy()
    if 'shuffle' in _kwargs:
        shuffle = _kwargs.pop('shuffle')
    else:
        shuffle = False

    if add_sampler_if_possible and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
        sampler = DataParallelSampler(dataset, shuffle=shuffle)
    else:
        sampler = None

    # Deterministic dataloader
    def seed_worker(worker_id):
        worker_seed = seed
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)
        random.seed(worker_seed)

    if sampler is None:
        return DataLoader(dataset,
                          worker_init_fn=seed_worker,
                          shuffle=shuffle,
                          **_kwargs)
    else:
        return DataLoader(dataset,
                          sampler=sampler,
                          worker_init_fn=seed_worker,
                          **_kwargs)

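
# Usage sketch (illustrative only, not part of the original file): builds a
# deterministic loader over a toy TensorDataset. batch_size and num_workers are
# forwarded to torch.utils.data.DataLoader through **kwargs, while shuffle is
# handled by get_dataloader itself.
def _example_get_dataloader():
    from torch.utils.data import TensorDataset
    dataset = TensorDataset(torch.randn(8, 4), torch.zeros(8, dtype=torch.long))
    return get_dataloader(dataset,
                          seed=1024,
                          add_sampler_if_possible=False,
                          batch_size=4,
                          num_workers=0,
                          shuffle=True)
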
def initialize(config: Union[str, dict] = None,
               local_rank: int = None,
               world_size: int = None,
               host: str = None,
               port: str = None,
               backend: str = None,
               train_dataloader: Optional[Union[Iterable, Callable]] = None,
               test_dataloader: Optional[Union[Iterable, Callable]] = None,
               ) -> Tuple[Engine, DataLoader, DataLoader]:
    '''Core function that initializes the distributed environment, logger, cudnn, data, model, loss function, optimizer and lr_scheduler (their configs are in gpc.config).

    :param config: a config dict or a path to a config file are both acceptable
    :type config: Union[str, dict], optional
    :param local_rank: rank for the default process group, defaults to None
    :type local_rank: int, optional
    :param world_size: world size of GPUs, defaults to None
    :type world_size: int, optional
    :param host: the master address for distributed training, defaults to None
    :type host: str, optional
    :param port: the master port for distributed training, defaults to None
    :type port: str, optional
    :param backend: backend for torch.distributed, defaults to None
    :type backend: str, optional
    :param train_dataloader: if None, the config is used to build a dataloader; else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
    :type train_dataloader: Optional[Union[Iterable, Callable]], optional
    :param test_dataloader: if None, the config is used to build a dataloader; else, it should be a dataloader object or a function with no arguments which can build a dataloader, defaults to None
    :type test_dataloader: Optional[Union[Iterable, Callable]], optional
    :return: (engine, train_dataloader, test_dataloader)
    :rtype: Tuple[Engine, DataLoader, DataLoader]
    '''
    # initialize distributed environment
    init_dist(config=config,
              local_rank=local_rank,
              world_size=world_size,
              host=host,
              port=port,
              backend=backend)

    # init logger
    logger = get_global_dist_logger()
    logger.info(f'Distributed environment is initialized, '
                f'data parallel size: {gpc.data_parallel_size}, pipeline parallel size: {gpc.pipeline_parallel_size}, '
                f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])

    # print config
    logger.info(f"\n========== Your Config ========\n"
                f"{pprint.pformat(gpc.config)}\n"
                f"================================", ranks=[0])

    # cudnn
    cudnn_benchmark = gpc.config.get('cudnn_benchmark', True)
    cudnn_deterministic = gpc.config.get('cudnn_deterministic', False)
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.backends.cudnn.deterministic = cudnn_deterministic
    logger.info(
        f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])

    # set seed, the cuda seed is only set when cuda is available
    gpc.set_seed()

    # return_items = list()

    # check fp16 and zero
    should_convert_model_to_half = False
    should_wrap_fp16_optimizer = False
    should_wrap_zero_optimizer_level_2_3 = False

    if hasattr(gpc.config, 'fp16'):
        fp16_mode = gpc.config.fp16.mode
        if fp16_mode == AMP_TYPE.PARALLEL:
            should_convert_model_to_half = True
            should_wrap_fp16_optimizer = True

    if hasattr(gpc.config, 'zero'):
        should_wrap_zero_optimizer_level_2_3 = True
        zero_type = gpc.config.zero.type
        if zero_type in ['ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3']:
            should_convert_model_to_half = True
            assert not should_wrap_fp16_optimizer, \
                'AMP_TYPE.PARALLEL is mutually exclusive with zero level 2 and 3'
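
    # Illustrative sketch of the config entries consumed by the checks above; only
    # 'mode' and 'type' are read here, any further fields are assumptions that
    # depend on the chosen optimizer wrapper:
    #   fp16 = dict(mode=AMP_TYPE.PARALLEL, ...)
    #   zero = dict(type='ZeroRedundancyOptimizer_Level_2', ...)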

    # build model
    logger.info('Building model ...', ranks=[0])
    assert hasattr(
        gpc.config, 'model'), "Build error: configuration 'model' is missing"
    if gpc.pipeline_parallel_size > 1:
        model = ModelInitializer(gpc.config.model, 1, verbose=True)
        model = model.model_initialize()
    else:
        model = build_model(gpc.config.model)
        if isinstance(model, BaseModel):
            model.build_from_cfg()
        model = model.to(get_current_device())
    sync_model_param_in_dp(model)
    logger.info('Model is created', ranks=[0])

    if should_convert_model_to_half:
        model = model.half()
        logger.info("Model is cast to fp16", ranks=[0])

    # training data
    if callable(train_dataloader):
        logger.info(
            f'Build train data loader from {train_dataloader}', ranks=[0])
        train_dataloader = train_dataloader()
    if train_dataloader is None and hasattr(gpc.config, 'train_data'):
        logger.info('Preparing data ...', ranks=[0])
        # assert hasattr(gpc.config, 'train_data'), "Build error: configuration 'train_data' is missing."
        train_dataset = build_dataset(gpc.config.train_data.dataset)
        logger.info('Train dataset is ready.', ranks=[0])

        train_dataloader = get_dataloader(train_dataset,
                                          gpc.config.get('seed', 1024),
                                          True,
                                          **gpc.config.train_data.dataloader,
                                          )
        logger.info(
            f'Loaded {len(train_dataset)} samples in {len(train_dataloader)} batches for training', ranks=[0])

    if callable(test_dataloader):
        logger.info(
            f'Build test data loader from {test_dataloader}', ranks=[0])
        test_dataloader = test_dataloader()
    # testing data, allowed to be None
    if test_dataloader is None and hasattr(gpc.config, 'test_data'):
        test_dataset = build_dataset(gpc.config.test_data.dataset)
        test_dataloader = get_dataloader(
            test_dataset, add_sampler_if_possible=True, **gpc.config.test_data.dataloader)
        logger.info(
            f'Loaded {len(test_dataset)} samples in {len(test_dataloader)} batches for testing', ranks=[0])

    # build loss function
    assert hasattr(gpc.config, 'loss'), \
        'Build error: configuration \'loss\' is missing.'
    criterion = build_loss(gpc.config.loss)
    logger.info('Loss function is created', ranks=[0])

    # build optimizer
    assert hasattr(gpc.config, 'optimizer'), \
        "Build error: configuration 'optimizer' is missing."
    optim_type = gpc.config.optimizer.type
    is_pytorch_native_zero_level_1 = optim_type == 'ZeroRedundancyOptimizer'
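    # 'ZeroRedundancyOptimizer' here refers to PyTorch's native optimizer-state
    # sharding (ZeRO stage 1); it is built with the data-parallel process group
    # passed in explicitly.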
    if is_pytorch_native_zero_level_1:
        original_cfg_copy = gpc.config.optimizer.copy()
        original_cfg_copy.pop('type')
        cfg = dict(type=optim_type, process_group=gpc.get_group(
            ParallelMode.DATA), **original_cfg_copy)
        optimizer = build_optimizer(cfg, model)
    else:
        optimizer = build_optimizer(gpc.config.optimizer, model)

    if should_wrap_zero_optimizer_level_2_3:
        optimizer = build_optimizer_wrapper(gpc.config.zero, optimizer, model)

    if should_wrap_fp16_optimizer:
        # replace the field mode with type
        fp16_cfg = gpc.config.fp16.copy()
        amp_type = fp16_cfg.pop('mode')
        assert amp_type == AMP_TYPE.PARALLEL, 'FP16 Optimizer should only be used for AMP_TYPE.PARALLEL'
        fp16_cfg['type'] = 'FP16Optimizer'
        optimizer = build_optimizer_wrapper(fp16_cfg, optimizer)
    logger.info('Optimizer is created', ranks=[0])

    # build schedule and engine
    if hasattr(gpc.config, 'fp16'):
        amp_type = gpc.config.fp16.mode
        amp_cfg = gpc.config.fp16.copy()
        amp_cfg.pop('mode')
    else:
        amp_type = None
        amp_cfg = None

    engine_cfg = gpc.config.get('engine', dict())
    schedule_cfg = engine_cfg.pop('schedule', None)

    schedule_type = None
    if schedule_cfg is not None:
        schedule_type = schedule_cfg.get('type', None)

    if schedule_type is not None:
        # run customized schedule
        schedule_cfg['amp_type'] = amp_type
        schedule_cfg['amp_config'] = amp_cfg
        schedule = build_schedule(schedule_cfg)
    elif gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
        assert schedule_cfg is not None, \
            "Config 'engine.schedule' not found in your configuration file for pipeline parallel training"
        schedule = PipelineSchedule(
            amp_type=amp_type, amp_config=amp_cfg, **schedule_cfg.copy())
    else:
        schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg)
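
    # The schedule defines how a batch is executed per step: PipelineSchedule drives
    # pipelined execution across pipeline stages, while NoPipelineSchedule performs a
    # single forward/backward pass.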

    engine = Engine(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        step_schedule=schedule,
        **gpc.config.get('engine', dict())
    )

    return engine, train_dataloader, test_dataloader
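

# Usage sketch (illustrative only, not part of the original file): a typical entry
# point would call initialize() and hand the returned engine and dataloaders to a
# training loop; the config path below is a placeholder, and any arguments not
# passed in are read from the command line via parse_args().
def _example_initialize():
    engine, train_dataloader, test_dataloader = initialize(config='./config.py')
    return engine, train_dataloader, test_dataloader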