đŸ”ĸ Distributed Tensor

📚 Table of Contents

  â€ĸ 🔗 Introduction
  â€ĸ 📝 Design
  â€ĸ 🔨 Usage
  â€ĸ 🎈 Progress Log

🔗 Introduction

A distributed tensor is a tensor that is sharded across multiple devices. It wraps a regular PyTorch tensor to support distributed training: it records the device topology and the tensor's placement over the devices in that topology, and it provides a set of APIs to manipulate the distributed tensor.

📝 Design

Our implementation is inspired by Alpa, which unifies data parallelism and tensor parallelism as intra-op parallelism. It uses the notation S to denote a sharded dimension and R to denote a replicated dimension. For example, given a 2D matrix, [S, R] means the tensor is sharded along the first dimension and replicated along the second.

Each sharded dimension carries a subscript indicating which device-mesh axis it is sharded over. Assume we have 4 GPUs arranged in a 2 x 2 device mesh, and a 2D matrix like the one below:

    [1,  2,  3,  4 ]
A = [5,  6,  7,  8 ]
    [9,  10, 11, 12]
    [13, 14, 15, 16]

[S0, R] means the first dimension is sharded over the rows of the device mesh (mesh axis 0) and replicated over the columns:

|---------------------|---------------------|
|                     |                     |
|  [1,  2,  3,  4 ]   |  [1,  2,  3,  4 ]   |
|  [5,  6,  7,  8 ]   |  [5,  6,  7,  8 ]   |
|                     |                     |
|---------------------|---------------------|
|                     |                     |
|  [9,  10, 11, 12]   |  [9,  10, 11, 12]   |
|  [13, 14, 15, 16]   |  [13, 14, 15, 16]   |
|                     |                     |
|---------------------|---------------------|

[S01, R] means the first dimension is sharded over both the rows and the columns of the device mesh (mesh axes 0 and 1):

|---------------------|---------------------|
|                     |                     |
|  [1,  2,  3,  4 ]   |  [5,  6,  7,  8 ]   |
|                     |                     |
|---------------------|---------------------|
|                     |                     |
|  [9,  10, 11, 12]   |  [13, 14, 15, 16]   |
|                     |                     |
|---------------------|---------------------|
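
In code, these placements are expressed as a dim_partition_dict that maps each sharded tensor dimension to the list of device-mesh axes it is sharded over; replicated dimensions are simply omitted. The [S0, R] entry below is the one used in the Usage section; the other entries are sketches following the same pattern:

# dim_partition_dict maps {tensor dim: [device mesh axes]}
spec_s0_r  = {0: [0]}     # [S0, R] : dim 0 sharded over mesh axis 0
spec_s01_r = {0: [0, 1]}  # [S01, R]: dim 0 sharded over mesh axes 0 and 1
spec_r_s1  = {1: [1]}     # [R, S1] : dim 1 sharded over mesh axis 1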

🔨 Usage

A sample API usage is given below.

import torch

import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor import DTensor, ShardingSpec

colossalai.launch_from_torch()

# define your device mesh
# assume you have 4 GPUs
physical_mesh_id = torch.arange(0, 4)
mesh_shape = (2, 2)
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

# define a tensor
a = torch.rand(16, 32).cuda()

# create a sharding spec for the tensor
# {0: [0]} shards dim 0 over mesh axis 0, i.e. [S0, R]
dim_partition_dict = {0: [0]}
sharding_spec = ShardingSpec(a.dim(), dim_partition_dict)

# create a distributed tensor
d_tensor = DTensor(a, device_mesh, sharding_spec)
print(d_tensor)

# gather the shards back into the full global tensor
global_tensor = d_tensor.to_global()
print(global_tensor)
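
Since colossalai.launch_from_torch() reads the distributed environment variables set by the PyTorch launcher, a script like the one above is typically started with torchrun; the script name here is just a placeholder:

torchrun --nproc_per_node 4 sample_dtensor.py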

🎈 Progress Log

  • Support layout conversion
  • Support sharding on 2D device mesh
  • Support sharding on 3D device mesh
  • Support sharding 4D device mesh
  • Support sharding info saving and offline tensor merge (we can save tensor as dtensor and gather the tensors back to the global tensor based on the sharding info in a single process in CPU, useful for distributed training checkpoint load and save.)