#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import argparse
import pprint
import random
from pathlib import Path
from typing import Callable, Iterable, Optional, Tuple, Union

import numpy as np
import torch
from torch.utils.data import DataLoader

from colossalai.engine import AMP_TYPE, Engine, NoPipelineSchedule, PipelineSchedule
from colossalai.logging import get_global_dist_logger, init_global_dist_logger
from colossalai.nn import DataParallelSampler
from colossalai.nn.model.base_model import BaseModel
from .builder import (ModelInitializer, build_dataset, build_loss, build_model,
                      build_optimizer, build_optimizer_wrapper, build_schedule)
from .context import Config, ParallelMode
from .core import global_context as gpc
from .utils import get_current_device, sync_model_param_in_dp


def parse_args():
    '''Reads the user command line and uses an argument parser to parse the input arguments.
    Input arguments include configuration, host, port, world size, local rank and backend
    for torch.distributed.

    :return: the parsed arguments
    :rtype: Namespace
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, help='path to the config file')
    parser.add_argument('--host', type=str, default=None,
                        help='the master address for distributed training')
    parser.add_argument('--port', type=str, default=None,
                        help='the master port for distributed training')
    parser.add_argument('--world_size', type=int,
                        help='world size for distributed training')
    parser.add_argument('--local_rank', type=int,
                        help='rank for the default process group')
    parser.add_argument('--backend', type=str, default='nccl',
                        help='backend for torch.distributed')
    return parser.parse_args()
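
# Example command line that this parser accepts (sketch only; the script name,
# master address/port and world size below are placeholders for illustration):
#   python train.py --config ./config.py --host localhost --port 29500 \
#       --world_size 4 --local_rank 0 --backend nccl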

def init_dist(config: Union[str, dict] = None,
              local_rank: int = None,
              world_size: int = None,
              host: str = None,
              port: str = None,
              backend: str = None):
    '''This function first parses the configuration arguments, using :func:`parse_args()` in case
    any of the input arguments is not given. It then initializes and sets up the distributed
    environment by calling global_context's functions.

    :param config: config file or config file path are both acceptable
    :type config: Union[str, dict], optional
    :param local_rank: rank for the default process group, defaults to None
    :type local_rank: int, optional
    :param world_size: number of GPUs in the cluster (world size), defaults to None
    :type world_size: int, optional
    :param host: the master address for distributed training, defaults to None
    :type host: str, optional
    :param port: the master port for distributed training, defaults to None
    :type port: str, optional
    :param backend: backend for torch.distributed, defaults to None
    :type backend: str, optional
    :raises Exception: raise an exception when the config type is wrong
    '''
    args = [config, local_rank, world_size, host, port, backend]
    arg_given = [arg is not None for arg in args]

    if not all(arg_given):
        args = parse_args()

    if config is None:
        config = args.config
    if local_rank is None:
        local_rank = args.local_rank
    if world_size is None:
        world_size = args.world_size
    if host is None:
        host = args.host
    if port is None:
        port = args.port
    if backend is None:
        backend = args.backend
    args = Config(
        dict(config=config,
             host=host,
             port=port,
             world_size=world_size,
             local_rank=local_rank,
             backend=backend))

    # set distributed settings
    dist_args = Config(
        dict(local_rank=args.local_rank,
             world_size=args.world_size,
             backend=args.backend))

    gpc.set_dist_args(dist_args)

    # set config
    if isinstance(args.config, dict):
        cfg = args.config
    elif isinstance(args.config, (str, Path)):
        cfg = Config.from_file(args.config)
    else:
        raise Exception('Config type error: {}'.format(type(args.config)))
    gpc.load_config(cfg)

    # init dist groups
    gpc.init_global_dist(args.host, args.port)
    gpc.init_parallel_groups()

    # init dist logger
    init_global_dist_logger()

    # set cuda device
    if torch.cuda.is_available():
        gpc.set_device()
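
# Usage sketch (the config path and connection values are placeholders):
#   init_dist(config='./config.py', local_rank=0, world_size=1,
#             host='localhost', port='29500', backend='nccl')
# If any argument is omitted, the missing values are read from the command line
# via parse_args().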

def get_dataloader(dataset, seed=1024, add_sampler_if_possible=False, **kwargs):
    '''Set up a deterministic dataloader (also configure seed workers, samplers and whether
    to shuffle or not).

    .. note: When pipeline parallelism is enabled, shuffle cannot be True, as it would result in
        a mismatch between the input data on the first stage and the labels on the last stage.

    :param dataset: a :class:`torch.utils.data.Dataset` dataset
    :param seed: random worker seed, defaults to 1024
    :type seed: int, optional
    :param add_sampler_if_possible: add a :class:`DataParallelSampler` when data parallelism
        is enabled, defaults to False
    :type add_sampler_if_possible: bool, optional
    :return: a dataloader over the given dataset
    :rtype: torch.utils.data.DataLoader
    '''
    _kwargs = kwargs.copy()
    if 'shuffle' in _kwargs:
        shuffle = _kwargs.pop('shuffle')
    else:
        shuffle = False

    if add_sampler_if_possible and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(
            ParallelMode.DATA) > 1:
        sampler = DataParallelSampler(dataset, shuffle=shuffle)
    else:
        sampler = None

    # Deterministic dataloader
    def seed_worker(worker_id):
        worker_seed = seed
        np.random.seed(worker_seed)
        torch.manual_seed(worker_seed)
        random.seed(worker_seed)

    if sampler is None:
        return DataLoader(dataset,
                          worker_init_fn=seed_worker,
                          shuffle=shuffle,
                          **_kwargs)
    else:
        return DataLoader(dataset,
                          sampler=sampler,
                          worker_init_fn=seed_worker,
                          **_kwargs)
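
# Usage sketch (my_dataset is a hypothetical torch Dataset; batch_size and
# num_workers are illustrative kwargs forwarded to torch.utils.data.DataLoader):
#   train_loader = get_dataloader(my_dataset, seed=1024, add_sampler_if_possible=True,
#                                 batch_size=32, shuffle=True, num_workers=2)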

def initialize(config: Union[str, dict] = None,
               local_rank: int = None,
               world_size: int = None,
               host: str = None,
               port: str = None,
               backend: str = None,
               train_dataloader: Optional[Union[Iterable, Callable]] = None,
               test_dataloader: Optional[Union[Iterable, Callable]] = None,
               ) -> Tuple[Engine, DataLoader, DataLoader]:
    '''Core function that initializes the distributed environment, logger, cudnn, data, model,
    loss function, optimizer, and lr scheduler (their configs are in gpc.config).

    :param config: config file or config file path are both acceptable
    :type config: Union[str, dict], optional
    :param local_rank: rank for the default process group, defaults to None
    :type local_rank: int, optional
    :param world_size: number of GPUs in the cluster (world size), defaults to None
    :type world_size: int, optional
    :param host: the master address for distributed training, defaults to None
    :type host: str, optional
    :param port: the master port for distributed training, defaults to None
    :type port: str, optional
    :param backend: backend for torch.distributed, defaults to None
    :type backend: str, optional
    :param train_dataloader: If None, the config is used to build a dataloader; otherwise, it should be a
        dataloader object or a function with no arguments which can build a dataloader, defaults to None
    :type train_dataloader: Optional[Union[Iterable, Callable]], optional
    :param test_dataloader: If None, the config is used to build a dataloader; otherwise, it should be a
        dataloader object or a function with no arguments which can build a dataloader, defaults to None
    :type test_dataloader: Optional[Union[Iterable, Callable]], optional
    :return: (engine, train_dataloader, test_dataloader)
    :rtype: Tuple[Engine, DataLoader, DataLoader]
    '''
    # initialize distributed environment
    init_dist(config=config,
              local_rank=local_rank,
              world_size=world_size,
              host=host,
              port=port,
              backend=backend)

    # init logger
    logger = get_global_dist_logger()
    logger.info(f'Distributed environment is initialized, '
                f'data parallel size: {gpc.data_parallel_size}, '
                f'pipeline parallel size: {gpc.pipeline_parallel_size}, '
                f'tensor parallel size: {gpc.tensor_parallel_size}', ranks=[0])

    # print config
    logger.info(f"\n========== Your Config ========\n"
                f"{pprint.pformat(gpc.config)}\n"
                f"================================", ranks=[0])

    # cudnn
    cudnn_benchmark = gpc.config.get('cudnn_benchmark', True)
    cudnn_deterministic = gpc.config.get('cudnn_deterministic', False)
    torch.backends.cudnn.benchmark = cudnn_benchmark
    torch.backends.cudnn.deterministic = cudnn_deterministic
    logger.info(
        f"cuDNN benchmark = {cudnn_benchmark}, deterministic = {cudnn_deterministic}", ranks=[0])

    # set seed; the cuda seed is only set when cuda is available
    gpc.set_seed()

    # check fp16 and zero
    should_convert_model_to_half = False
    should_wrap_fp16_optimizer = False
    should_wrap_zero_optimizer_level_2_3 = False

    if hasattr(gpc.config, 'fp16'):
        fp16_mode = gpc.config.fp16.mode
        if fp16_mode == AMP_TYPE.PARALLEL:
            should_convert_model_to_half = True
            should_wrap_fp16_optimizer = True

    if hasattr(gpc.config, 'zero'):
        should_wrap_zero_optimizer_level_2_3 = True
        zero_type = gpc.config.zero.type
        if zero_type in ['ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3']:
            should_convert_model_to_half = True
            assert not should_wrap_fp16_optimizer, \
                'AMP_TYPE.PARALLEL is mutually exclusive with zero level 2 and 3'

    # build model
    logger.info('Building model ...', ranks=[0])
    assert hasattr(
        gpc.config, 'model'), "Build error: configuration 'model' is missing"
    if gpc.pipeline_parallel_size > 1:
        model = ModelInitializer(gpc.config.model, 1, verbose=True)
        model = model.model_initialize()
    else:
        model = build_model(gpc.config.model)
        if isinstance(model, BaseModel):
            model.build_from_cfg()
        model = model.to(get_current_device())
    sync_model_param_in_dp(model)
    logger.info('Model is created', ranks=[0])

    if should_convert_model_to_half:
        model = model.half()
        logger.info("Model is cast to fp16", ranks=[0])

    # training data
    if callable(train_dataloader):
        logger.info(
            f'Build train data loader from {train_dataloader}', ranks=[0])
        train_dataloader = train_dataloader()
    if train_dataloader is None and hasattr(gpc.config, 'train_data'):
        logger.info('Preparing data ...', ranks=[0])
        train_dataset = build_dataset(gpc.config.train_data.dataset)
        logger.info('Train dataset is ready.', ranks=[0])

        train_dataloader = get_dataloader(train_dataset,
                                          gpc.config.get('seed', 1024),
                                          True,
                                          **gpc.config.train_data.dataloader,
                                          )
        logger.info(
            f'Loaded {len(train_dataset)} samples in {len(train_dataloader)} batches for training', ranks=[0])

    if callable(test_dataloader):
        logger.info(
            f'Build test data loader from {test_dataloader}', ranks=[0])
        test_dataloader = test_dataloader()
    # testing data, allowed to be None
    if test_dataloader is None and hasattr(gpc.config, 'test_data'):
        test_dataset = build_dataset(gpc.config.test_data.dataset)
        test_dataloader = get_dataloader(
            test_dataset, add_sampler_if_possible=True, **gpc.config.test_data.dataloader)
        logger.info(
            f'Loaded {len(test_dataset)} samples in {len(test_dataloader)} batches for testing', ranks=[0])

    # build loss function
    assert hasattr(gpc.config, 'loss'), \
        "Build error: configuration 'loss' is missing."
    criterion = build_loss(gpc.config.loss)
    logger.info('Loss function is created', ranks=[0])

    # build optimizer
    assert hasattr(gpc.config, 'optimizer'), \
        "Build error: configuration 'optimizer' is missing."
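    # NOTE: PyTorch's native ZeroRedundancyOptimizer (ZeRO level 1) shards optimizer
    # states across a process group, so the data-parallel group is injected into its
    # config below; all other optimizers are built directly from gpc.config.optimizer.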
    optim_type = gpc.config.optimizer.type
    is_pytorch_native_zero_level_1 = optim_type == 'ZeroRedundancyOptimizer'
    if is_pytorch_native_zero_level_1:
        original_cfg_copy = gpc.config.optimizer.copy()
        original_cfg_copy.pop('type')
        cfg = dict(type=optim_type, process_group=gpc.get_group(
            ParallelMode.DATA), **original_cfg_copy)
        optimizer = build_optimizer(cfg, model)
    else:
        optimizer = build_optimizer(gpc.config.optimizer, model)

    if should_wrap_zero_optimizer_level_2_3:
        optimizer = build_optimizer_wrapper(gpc.config.zero, optimizer, model)

    if should_wrap_fp16_optimizer:
        # replace the field mode with type
        fp16_cfg = gpc.config.fp16.copy()
        amp_type = fp16_cfg.pop('mode')
        assert amp_type == AMP_TYPE.PARALLEL, 'FP16 optimizer should only be used for AMP_TYPE.PARALLEL'
        fp16_cfg['type'] = 'FP16Optimizer'
        optimizer = build_optimizer_wrapper(fp16_cfg, optimizer)
    logger.info('Optimizer is created', ranks=[0])

    # build schedule and engine
    if hasattr(gpc.config, 'fp16'):
        amp_type = gpc.config.fp16.mode
        amp_cfg = gpc.config.fp16.copy()
        amp_cfg.pop('mode')
    else:
        amp_type = None
        amp_cfg = None

    engine_cfg = gpc.config.get('engine', dict())
    schedule_cfg = engine_cfg.pop('schedule', None)

    schedule_type = None
    if schedule_cfg is not None:
        schedule_type = schedule_cfg.get('type', None)

    if schedule_type is not None:
        # run a customized schedule
        schedule_cfg['amp_type'] = amp_type
        schedule_cfg['amp_config'] = amp_cfg
        schedule = build_schedule(schedule_cfg)
    elif gpc.is_initialized(ParallelMode.PIPELINE) and gpc.get_world_size(ParallelMode.PIPELINE) > 1:
        assert schedule_cfg is not None, \
            "Config 'engine.schedule' not found in your configuration file for pipeline parallel training"
        schedule = PipelineSchedule(
            amp_type=amp_type, amp_config=amp_cfg, **schedule_cfg.copy())
    else:
        schedule = NoPipelineSchedule(amp_type=amp_type, amp_config=amp_cfg)

    engine = Engine(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        step_schedule=schedule,
        **gpc.config.get('engine', dict())
    )

    return engine, train_dataloader, test_dataloader
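
if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the library API: assumes the module is
    # launched with the command-line flags defined in parse_args() (e.g. through a
    # torch.distributed launcher) and that the config file defines 'model',
    # 'train_data', 'loss' and 'optimizer' sections.
    engine, train_dataloader, test_dataloader = initialize()
    get_global_dist_logger().info('initialize() finished', ranks=[0])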