
Colossal-AUTO

Challenges

Recently, large models have achieved state-of-the-art performance in various fields. Training such models requires distributed training techniques. However, finding an efficient distributed execution plan not only requires fine-grained model statistics, such as the memory and compute overhead of each operator, but is also labor-intensive, even for experts in distributed training.

Our solution

To simplify distributed training of foundation models, recent advances in machine learning systems have produced automatic parallel systems. We survey a number of existing automatic parallel systems (Tofu, FlexFlow, Alpa) and automatic activation checkpointing algorithms (Rotor, Sublinear). Inspired by these systems, we build an automatic parallel system on top of the PyTorch framework. The input to the system is serial PyTorch code, and the output is a PyTorch program with an optimized distributed execution plan. Because the output is a regular PyTorch program, it remains compatible with runtime optimization methods such as ZeRO-Offload and PatrickStar.

Key modules

Analyzer

Analyzer is a static analysis system consisting of three parts: a symbolic profiler that collects the compute and memory overhead of the static computation graph; a cluster detector that collects hardware characteristics and detects the cluster topology; and a tensor layout manager that finds efficient conversion paths between tensor layouts with different sharding specs and records the conversion costs.
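As a rough illustration of the tensor layout manager's job, the sketch below treats each layout as a node in a graph, runs Dijkstra's shortest-path search over single-step conversions, and memoizes every (source, target) result so the cost is recorded once. The layout encoding (a 2-D tensor on a 2-D mesh), the three conversion operations, and their costs are all assumptions made for illustration; they are not Colossal-AUTO's actual sharding-spec classes or cost model.

```python
import heapq
from functools import lru_cache

# Hypothetical layout encoding: a 2-D tensor on a 2-D device mesh. Each tensor
# dim is "R" (replicated) or "S0"/"S1" (sharded along mesh axis 0 or 1).
MESH_AXES = ("S0", "S1")

# Toy cost of one conversion step (an assumption, not a measured cost model).
COST = {"all_gather": 1.0, "all_to_all": 0.5, "shard": 0.1}


def valid(spec):
    """A mesh axis may shard at most one tensor dim."""
    used = [s for s in spec if s != "R"]
    return len(used) == len(set(used))


def neighbors(spec):
    """Yield (next_spec, op) for every single-step layout conversion."""
    for i, s in enumerate(spec):
        if s != "R":
            # all-gather dim i back to replicated
            yield spec[:i] + ("R",) + spec[i + 1:], "all_gather"
            # all-to-all: move mesh axis s from dim i to a replicated dim j
            for j, t in enumerate(spec):
                if j != i and t == "R":
                    nxt = list(spec)
                    nxt[i], nxt[j] = "R", s
                    yield tuple(nxt), "all_to_all"
        else:
            # local slicing of a replicated dim needs no communication
            for axis in MESH_AXES:
                nxt = spec[:i] + (axis,) + spec[i + 1:]
                if valid(nxt):
                    yield nxt, "shard"


@lru_cache(maxsize=None)  # memoize: each (src, dst) pair is solved once and recorded
def conversion_path(src, dst):
    """Dijkstra over layouts; returns (total_cost, ops) for the cheapest path."""
    frontier = [(0.0, src, ())]
    best = {src: 0.0}
    while frontier:
        cost, spec, ops = heapq.heappop(frontier)
        if spec == dst:
            return cost, ops
        if cost > best[spec]:
            continue  # stale heap entry
        for nxt, op in neighbors(spec):
            new_cost = cost + COST[op]
            if new_cost < best.get(nxt, float("inf")):
                best[nxt] = new_cost
                heapq.heappush(frontier, (new_cost, nxt, ops + (op,)))
    raise ValueError(f"no conversion path from {src} to {dst}")


print(conversion_path(("S0", "R"), ("R", "S0")))  # (0.5, ('all_to_all',))
```

Under these toy costs, moving the shard from dim 0 to dim 1 with a single all-to-all (0.5) beats an all-gather followed by re-sharding (1.1); a real cost model would derive these numbers from tensor sizes and the detected cluster topology.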

Solver

Solver is designed to find the optimal execution plan for a given computation graph and cluster in two stages:

  1. The intra-op parallelism stage finds the plan that minimizes the total execution time of all nodes subject to the memory budget. The optimization objective of the intra-op parallelism solver is adapted from Alpa's intra-op parallelism ILP solver.
  2. The activation checkpoint stage searches for the fastest execution plan that meets the memory budget on the computation graph produced by the intra-op parallelism stage, that is, after the communication nodes it requires have been inserted. The algorithm for finding the optimal activation checkpointing plan is adapted from Rotor.

We use two-stage optimization because formulating the two tasks jointly would significantly increase the solving time, which would greatly hurt the user experience of the system. Solving them at two hierarchical levels instead has several advantages. First, the original graph has fewer nodes than the graph with activation checkpointing, which reduces the solving cost of the intra-op parallelism solver. Second, a better solution can be found by including the communication overhead in the activation checkpoint modeling.
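To make the two-stage split concrete, here is a toy sketch under heavy assumptions: each node offers a few hypothetical `Strategy` options with made-up costs, stage 1 uses exhaustive search in place of Alpa's ILP formulation, and stage 2 uses a 0/1 knapsack in place of Rotor's dynamic program. Splitting the memory budget into a parameter budget (stage 1) and an activation budget (stage 2) is also our own simplification. Neither function reflects Colossal-AUTO's actual solver code.

```python
import itertools
from dataclasses import dataclass


@dataclass(frozen=True)
class Strategy:
    """Hypothetical sharding option for one node (all numbers are toy values)."""
    name: str
    compute: float   # compute time under this strategy
    comm: float      # communication time this strategy inserts into the graph
    param_mem: int   # per-device parameter memory
    act_mem: int     # per-device activation memory saved for backward


def stage1_intra_op(nodes, mem_budget):
    """Stage 1: one strategy per node, minimizing total time under a memory budget.

    Exhaustive search stands in for an ILP solver; it only scales to toy graphs.
    """
    best = None
    for plan in itertools.product(*nodes):
        if sum(s.param_mem for s in plan) > mem_budget:
            continue
        total = sum(s.compute + s.comm for s in plan)
        if best is None or total < best[0]:
            best = (total, plan)
    if best is None:
        raise ValueError("no plan fits the memory budget")
    return best


def stage2_checkpoint(plan, act_budget):
    """Stage 2: on the fixed stage-1 plan, choose which activations to keep.

    0/1 knapsack: keep activations within act_budget so that the time spent
    recomputing the dropped ones is minimal. Recomputing a node also replays
    its communication, so comm is part of the recompute cost.
    """
    # dp[m] = (recompute time avoided, indices kept) using at most m memory units
    dp = [(0.0, ())] * (act_budget + 1)
    for i, s in enumerate(plan):
        cost = s.compute + s.comm
        for m in range(act_budget, s.act_mem - 1, -1):
            cand = dp[m - s.act_mem][0] + cost
            if cand > dp[m][0]:
                dp[m] = (cand, dp[m - s.act_mem][1] + (i,))
    saved, kept = dp[act_budget]
    recompute = sum(s.compute + s.comm for s in plan) - saved
    return recompute, kept


# Two nodes, each with a replicated and a sharded option (toy numbers).
nodes = [
    [Strategy("A-replicated", 4.0, 0.0, 4, 4), Strategy("A-sharded", 2.0, 1.0, 2, 2)],
    [Strategy("B-replicated", 4.0, 0.0, 4, 4), Strategy("B-sharded", 2.0, 1.0, 2, 2)],
]
total_time, plan = stage1_intra_op(nodes, mem_budget=6)
recompute_time, kept = stage2_checkpoint(plan, act_budget=2)
print(total_time, [s.name for s in plan], recompute_time, kept)
# 6.0 ['A-sharded', 'B-sharded'] 3.0 (0,)
```

Because stage 2 prices recomputation as compute plus communication, a node whose strategy inserts heavy communication becomes more valuable to keep in memory, which mirrors the stated benefit of including communication overhead in the activation checkpoint modeling.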

Generator

Generator applies the searched execution plan to the computation graph and recompiles the graph into optimized PyTorch code. It runs a series of compile passes that insert communication nodes or perform kernel substitution as required by the intra-op parallelism solver. Additionally, we implement a code generation feature that recognizes the annotations from the activation checkpoint solver and injects activation checkpoint blocks following the annotation instructions.
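The following sketch shows, under assumptions, how annotation-driven checkpoint injection could work: consecutive nodes tagged with the same region id are moved into a nested function that is called through `torch.utils.checkpoint.checkpoint`. The `Node` class, its `ckpt` annotation field, and the placeholder ops (`linear1`, `relu`, `linear2`, `softmax`) are hypothetical, and this toy assumes a linear chain where each node consumes only the previous node's output; the actual generator works on the full computation graph.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Node:
    name: str                   # output variable name
    expr: str                   # expression template; "{x}" is the node's input
    ckpt: Optional[int] = None  # hypothetical annotation: checkpoint region id


def generate_forward(nodes, input_name="x"):
    """Emit Python source for a linear chain, wrapping annotated runs in checkpoint calls."""
    lines = ["from torch.utils.checkpoint import checkpoint", "", f"def forward({input_name}):"]
    prev, i = input_name, 0
    while i < len(nodes):
        region = nodes[i].ckpt
        if region is None:
            # Unannotated node: emit it inline.
            lines.append(f"    {nodes[i].name} = {nodes[i].expr.format(x=prev)}")
            prev, i = nodes[i].name, i + 1
            continue
        # Collect the maximal run of consecutive nodes sharing this region id.
        j = i
        while j < len(nodes) and nodes[j].ckpt == region:
            j += 1
        fn = f"checkpoint_{region}"
        lines.append(f"    def {fn}(inp):")
        local = "inp"
        for node in nodes[i:j]:
            lines.append(f"        {node.name} = {node.expr.format(x=local)}")
            local = node.name
        lines.append(f"        return {local}")
        # Activations inside the region are recomputed during backward.
        lines.append(f"    {local} = checkpoint({fn}, {prev}, use_reentrant=False)")
        prev, i = local, j
    lines.append(f"    return {prev}")
    return "\n".join(lines)


nodes = [
    Node("h1", "linear1({x})"),
    Node("h2", "relu({x})", ckpt=0),
    Node("h3", "linear2({x})", ckpt=0),
    Node("out", "softmax({x})"),
]
src = generate_forward(nodes)
print(src)
```

Here `h2` and `h3` end up inside `checkpoint_0`, so only `h1` and the region's output `h3` are kept for backward. Generating source text rather than calling PyTorch directly keeps the result a regular PyTorch program; PyTorch is only needed when the emitted code runs.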