# 配置文件

下方代码块中的示例展示了如何在CIFAR10数据集上使用Colossal-AI训练ViT模型。

```python
# NOTE: IMG_SIZE, BATCH_SIZE, PATCH_SIZE, DIM, NUM_ATTENTION_HEADS, NUM_CLASSES,
# DEPTH and AMP_TYPE are not defined in this example; define (or import) them at
# the top of your config file, otherwise loading the config raises NameError.

# build train_dataset and train_dataloader from this dictionary
# It is not compulsory in the config file; instead, you can pass this dictionary
# as an argument to colossalai.initialize()
train_data = dict(
    # dictionary for building Dataset
    dataset=dict(
        # the type CIFAR10Dataset has to be registered
        type='CIFAR10Dataset',
        root='/path/to/data',
        # transform pipeline
        transform_pipeline=[
            dict(type='Resize', size=IMG_SIZE),
            dict(type='RandomCrop', size=IMG_SIZE, padding=4),
            dict(type='RandomHorizontalFlip'),
            dict(type='ToTensor'),
            dict(type='Normalize',
                 mean=[0.4914, 0.4822, 0.4465],
                 std=[0.2023, 0.1994, 0.2010]),
        ]
    ),
    # dictionary for building Dataloader
    dataloader=dict(
        batch_size=BATCH_SIZE,
        pin_memory=True,
        # num_workers=1,
        shuffle=True,
    )
)

# build test_dataset and test_dataloader from this dictionary
test_data = dict(
    dataset=dict(
        type='CIFAR10Dataset',
        root='/path/to/data',
        train=False,
        transform_pipeline=[
            dict(type='Resize', size=IMG_SIZE),
            dict(type='ToTensor'),
            dict(type='Normalize',
                 mean=[0.4914, 0.4822, 0.4465],
                 std=[0.2023, 0.1994, 0.2010]
                 ),
        ]
    ),
    dataloader=dict(
        batch_size=BATCH_SIZE,
        pin_memory=True,
        # num_workers=1,
    )
)

# compulsory
# build optimizer from this dictionary
optimizer = dict(
    # Available types: 'ZeroRedundancyOptimizer_Level_1', 'ZeroRedundancyOptimizer_Level_2', 'ZeroRedundancyOptimizer_Level_3'
    #                  'Adam', 'Lamb', 'SGD', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'FP16Optimizer'
    type='Adam',
    lr=0.001,
    weight_decay=0
)

# compulsory
# build loss function from this dictionary
loss = dict(
    # Available types:
    # 'CrossEntropyLoss2D', 'CrossEntropyLoss2p5D', 'CrossEntropyLoss3D'
    type='CrossEntropyLoss2D',
)

# compulsory
# build model from this dictionary
model = dict(
    # Available types: 'PretrainBERT', 'VanillaResNet', 'VisionTransformerFromConfig'
    type='VisionTransformerFromConfig',
    # each *_cfg key-value pair below refers to a layer
    # input data passes through these layers recursively
    tensor_splitting_cfg=dict(
        type='ViTInputSplitter2D',
    ),
    embedding_cfg=dict(
        type='ViTPatchEmbedding2D',
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        embed_dim=DIM,
    ),
    token_fusion_cfg=dict(
        type='ViTTokenFuser2D',
        img_size=IMG_SIZE,
        patch_size=PATCH_SIZE,
        embed_dim=DIM,
        drop_rate=0.1
    ),
    norm_cfg=dict(
        type='LayerNorm2D',
        normalized_shape=DIM,
        eps=1e-6,
    ),
    block_cfg=dict(
        # ViTBlock is a submodule
        type='ViTBlock',
        attention_cfg=dict(
            type='ViTSelfAttention2D',
            hidden_size=DIM,
            num_attention_heads=NUM_ATTENTION_HEADS,
            attention_dropout_prob=0.,
            hidden_dropout_prob=0.1,
            checkpoint=True
        ),
        droppath_cfg=dict(
            type='VanillaViTDropPath',
        ),
        mlp_cfg=dict(
            type='ViTMLP2D',
            in_features=DIM,
            dropout_prob=0.1,
            mlp_ratio=4,
            checkpoint=True
        ),
        norm_cfg=dict(
            type='LayerNorm2D',
            normalized_shape=DIM,
            eps=1e-6,
        ),
    ),
    head_cfg=dict(
        type='ViTHead2D',
        hidden_size=DIM,
        num_classes=NUM_CLASSES,
    ),
    # model-level (non-layer) settings
    embed_dim=DIM,
    depth=DEPTH,
    drop_path_rate=0.,
)

# hooks are built when initializing the trainer
# possible hooks: 'BaseHook', 'MetricHook', 'LoadCheckpointHook'
#                 'SaveCheckpointHook', 'LossHook', 'AccuracyHook', 'Accuracy2DHook'
#                 'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook'
hooks = [
    dict(type='LogMetricByEpochHook'),
    dict(type='LogTimingByEpochHook'),
    dict(type='LogMemoryByEpochHook'),
    dict(type='Accuracy2DHook'),
    dict(type='LossHook'),
    # dict(type='TensorboardHook', log_dir='./tfb_logs'),
    # dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
    # dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
]

# three keys: pipeline, tensor, data
# if data=dict(size=1), which means no data parallelization, then there is no need to define it
parallel = dict(
    pipeline=dict(size=1),
    tensor=dict(size=4, mode='2d'),
)

# not compulsory
# fp16 (AMP) settings
fp16 = dict(
    mode=AMP_TYPE.PARALLEL,
    initial_scale=2 ** 8
)

# not compulsory
# build learning rate scheduler
lr_scheduler = dict(
    type='LinearWarmupLR',
    warmup_epochs=5
)

# pipeline or no pipeline schedule
schedule = dict(
    num_microbatches=8
)

# training stopping criterion
# you can give num_steps or num_epochs
num_epochs = 60

# config logging path
logging = dict(
    root_path='./logs'
)
```