#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from enum import Enum


# parallel modes
class ParallelMode(Enum):
    """Enumeration of all possible parallel modes.

    Each member's value is a short string tag identifying the mode, so a
    member can be looked up by value, e.g. ``ParallelMode('pipe')`` is
    ``ParallelMode.PIPELINE``.
    """

    GLOBAL = 'global'

    # common parallel
    DATA = 'data'

    # model parallel - containing tensor and pipeline parallel groups
    # this is added to facilitate amp and grad clipping in hybrid parallel
    MODEL = 'model'

    # pipeline parallel
    PIPELINE = 'pipe'

    # containing all ranks in tensor parallel
    TENSOR = 'tensor'

    # sequence parallel
    SEQUENCE = 'sequence'
    SEQUENCE_DP = 'sequence_dp'

    # 1D parallel
    PARALLEL_1D = '1d'

    # 2D parallel
    PARALLEL_2D_ROW = '2d_row'
    PARALLEL_2D_COL = '2d_col'

    # 3D parallel
    PARALLEL_3D_INPUT = '3d_input'
    PARALLEL_3D_WEIGHT = '3d_weight'
    PARALLEL_3D_OUTPUT = '3d_output'

    # 2.5D parallel
    PARALLEL_2P5D_ROW = '2p5d_row'
    PARALLEL_2P5D_COL = '2p5d_col'
    PARALLEL_2P5D_DEP = '2p5d_dep'
    PARALLEL_2P5D_XZ = '2p5d_xz'