ColossalAI/colossalai/auto_parallel
YuliangLiu0306 ffcdbf0f65
[autoparallel]integrate auto parallel feature with new tracer (#3408)
* [autoparallel] integrate new analyzer in module level

* unify the profiling method

* polish

* fix no codegen bug

* fix pass bug

* fix liveness test

* polish
2023-04-04 17:40:45 +08:00
..
checkpoint [autoparallel] rotor solver refactor (#2813) 2023-02-18 11:30:15 +08:00
meta_profiler [autoparallel]integrate auto parallel feature with new tracer (#3408) 2023-04-04 17:40:45 +08:00
offload [zero] reorganize zero/gemini folder structure (#3424) 2023-04-04 13:48:16 +08:00
passes [autoparallel]integrate auto parallel feature with new tracer (#3408) 2023-04-04 17:40:45 +08:00
pipeline_shard [autoparallel] init new folder structure (#1696) 2022-10-13 14:18:55 +08:00
tensor_shard [autoparallel]integrate auto parallel feature with new tracer (#3408) 2023-04-04 17:40:45 +08:00
README.md [hotfix] add copyright for solver and device mesh (#2803) 2023-02-18 21:14:38 +08:00
__init__.py [autoparallel] standardize the code structure (#1469) 2022-08-19 15:51:54 +08:00

README.md

Colossal-AUTO

Challenges

Recently, large models have achieved the state of the art performances in various fields. In order to support large model training, we have to use distributed training techniques. However, finding an efficient distributed execution plan not only requires fine-grained model statistics, such as memory and computing overhead of each operator but also is a labor-intensive task even for an expert in the field of distributed training.

Our solution

To simplify the process of distributed training for foundational models, recent advancements in machine learning systems have led to the emergence of automatic parallel systems. We investigate and research a number of current automatic parallel systems( Tofu , Flexflow , Alpa ) and some auto activation checkpoint algorithms( Rotor , Sublinear ). Inspired from these advanced systems, we build an automatic parallel system upon PyTorch framework. The input of the system is the serial PyTorch code, and the output is a PyTorch program with an optimized distributed execution plan. It is worth emphasizing that the output is a regular PyTorch program, so it is compatible with runtime optimization methods, such as ZeRO-Offload and PatrickStar.

Key modules

Analyzer

Analyzer is a static analysis system consisting of three parts: A symbolic profiler for collecting computing and memory overhead related to static computation graph, a cluster detector for collecting hardware characteristics and detecting cluster topology and a tensor layout manager to find efficient tensor layout conversion path from different sharding spec and record conversion cost.

Solver

Solver is designed to find the optimal execution plan for a given computation graph and cluster in two stages:

  1. Intra-op parallelism stage is to find the plan with the minimum total execution time of all nodes with respect to the constraint of the memory budget. The optimaztion goal of intra-op parallelism solver is modified from Alpa 's intra-op parallelsim ILP solver.
  2. Activation checkpoint stage is to search for the fastest execution plan that meets the memory budget on the computation graph after inserting the communication nodes by the intra-op parallelism stage. The algorithm to find optimial activation checkpoint is modified from Rotor . The reason we use two-stage optimization is that if the two tasks are formulated together, the solving time will be significantly increased, which will greatly affect the user experience of the system. On the contrary, solving in two hierarchical levels has many advantages. Firstly, compared with the computation graph with activation checkpointing, the original graph has fewer nodes, which can reduce the solving cost of intra-op parallelism solver. In addition, a more optimal solution can be found by adding the communication overhead into the activation checkpoint modeling.

Generator

Generator applies the searched execution plan to the computation graph and recompiles the computation graph to optimized PyTorch code. It has a series compile pass to insert a communication node or do the kernel substitution as the intra-op parallelism solver required. Additionally, we implement a code generation feature to recognize the annotation from the activation checkpoint solver and inject the activation checkpoint block following annotation instructions.