#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from enum import Enum, unique


# parallel modes
@unique  # reject accidental duplicate values, which Enum would otherwise silently alias
class ParallelMode(Enum):
    """Enumeration of all possible parallel modes.

    Each member's value is a short string identifier, so a mode can be
    looked up from its identifier, e.g. ``ParallelMode('tensor')``.
    Values are guaranteed distinct by ``@unique``; member order is the
    declaration order below.
    """

    # NOTE(review): no original comment; presumably the group spanning all ranks — confirm
    GLOBAL = 'global'

    # common (data) parallel
    DATA = 'data'

    # pipeline parallel
    PIPELINE = 'pipe'
    PIPELINE_PREV = 'pipe_prev'
    PIPELINE_NEXT = 'pipe_next'

    # containing all ranks in tensor parallel
    TENSOR = 'tensor'

    # sequence parallel
    SEQUENCE = 'sequence'

    # 1D parallel
    PARALLEL_1D = '1d'

    # 2D parallel
    PARALLEL_2D_ROW = '2d_row'
    PARALLEL_2D_COL = '2d_col'

    # 3D parallel
    PARALLEL_3D_INPUT = '3d_input'
    PARALLEL_3D_WEIGHT = '3d_weight'
    PARALLEL_3D_OUTPUT = '3d_output'

    # 2.5D parallel
    PARALLEL_2P5D_ROW = '2p5d_row'
    PARALLEL_2P5D_COL = '2p5d_col'
    PARALLEL_2P5D_DEP = '2p5d_dep'
    PARALLEL_2P5D_XZ = '2p5d_xz'