[gemini] gemini supports tensor parallelism. (#4942)

* [colossalai]fix typo

* [inference] Add smmoothquant for llama (#4904)

* [inference] add int8 rotary embedding kernel for smoothquant (#4843)

* [inference] add smoothquant llama attention (#4850)

* add smoothquant llama attention

* remove useless code

* remove useless code

* fix import error

* rename file name

* [inference] add silu linear fusion for smoothquant llama mlp  (#4853)

* add silu linear

* update skip condition

* catch smoothquant cuda lib exception

* process exceptions for tests

* [inference] add llama mlp for smoothquant (#4854)

* add llama mlp for smoothquant

* fix down out scale

* remove duplicate lines

* add llama mlp check

* delete useless code

* [inference] add smoothquant llama (#4861)

* add smoothquant llama

* fix attention accuracy

* fix accuracy

* add kv cache and save pretrained

* refactor example

* delete smooth

* refactor code

* [inference] add smooth function and delete useless code for smoothquant (#4895)

* add smooth function and delete useless code

* update datasets

* remove duplicate import

* delete useless file

* refactor codes (#4902)

* refactor code

* add license

* add torch-int and smoothquant license

* Update flash_attention_patch.py

To be compatible with the new change in the Transformers library, where a new argument 'padding_mask' was added to the forward function of the attention layer.
https://github.com/huggingface/transformers/pull/25598

* [kernel] support pure fp16 for cpu adam and update gemini optim tests (#4921)

* [kernel] support pure fp16 for cpu adam (#4896)

* [kernel] fix cpu adam kernel for pure fp16 and update tests (#4919)

* [kernel] fix cpu adam

* [test] update gemini optim test

* [format] applied code formatting on changed files in pull request 4908 (#4918)

Co-authored-by: github-actions <github-actions@github.com>

* [gemini] support gradient accumulation (#4869)

* add test

* fix no_sync bug in low level zero plugin

* fix test

* add argument for grad accum

* add grad accum in backward hook for gemini

* finish implementation, rewrite tests

* fix test

* skip stuck model in low level zero test

* update doc

* optimize communication & fix gradient checkpoint

* modify doc

* cleaning codes

* update cpu adam fp16 case

* [hotfix] fix torch 2.0 compatibility (#4936)

* [hotfix] fix launch

* [test] fix test gemini optim

* [shardformer] fix vit

* [test] add no master test for low level zero plugin (#4934)

* [format] applied code formatting on changed files in pull request 4820 (#4886)

Co-authored-by: github-actions <github-actions@github.com>

* [nfc] fix some typo with colossalai/ docs/ etc. (#4920)

* [Refactor] Integrated some lightllm kernels into token-attention  (#4946)

* add some req for inference

* clean codes

* add codes

* add some lightllm deps

* clean codes

* hello

* delete rms files

* add some comments

* add comments

* add doc

* add lightllm deps

* add lightllm chatglm2 kernels

* add lightllm chatglm2 kernels

* replace rotary embedding with lightllm kernel

* add some comments

* add some comments

* add some comments

* add

* replace fwd kernel att1

* fix an arg

* add

* add

* fix token attention

* add some comments

* clean codes

* modify comments

* fix readme

* fix bug

* fix bug

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [test] merge old components to test to model zoo (#4945)

* [test] add custom models in model zoo

* [test] update legacy test

* [test] update model zoo

* [test] update gemini test

* [test] remove components to test

* [inference] add reference and fix some bugs (#4937)

* add reference and fix some bugs

* update gptq init

---------

Co-authored-by: Xu Kai <xukai16@foxmail.com>

* [Inference]ADD Bench Chatglm2 script (#4963)

* add bench chatglm

* fix bug and make utils

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [Pipeline inference] Combine kvcache with pipeline inference (#4938)

* merge kvcache with pipeline inference and refactor the code structure

* support ppsize > 2

* refactor pipeline code

* do pre-commit

* modify benchmark

* fix benchmark

* polish code

* add docstring and update readme

* refactor the code

* fix some logic bug of ppinfer

* polish readme

* fix typo

* skip infer test

* updated c++17 compiler flags (#4983)

* [Inference] Dynamic Batching Inference, online and offline (#4953)

* [inference] Dynamic Batching for Single and Multiple GPUs (#4831)

* finish batch manager

* 1

* first

* fix

* fix dynamic batching

* llama infer

* finish test

* support different lengths generating

* del prints

* del prints

* fix

* fix bug

---------

Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [inference] Async dynamic batching  (#4894)

* finish input and output logic

* add generate

* test forward

* 1

* [inference]Re push async dynamic batching (#4901)

* adapt to ray server

* finish async

* finish test

* del test

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* Revert "[inference]Re push async dynamic batching (#4901)" (#4905)

This reverts commit fbf3c09e67.

* Revert "[inference] Async dynamic batching  (#4894)"

This reverts commit fced140250.

* Revert "[inference] Async dynamic batching  (#4894)" (#4909)

This reverts commit fced140250.

* Add Ray Distributed Environment Init Scripts

* support DynamicBatchManager base function

* revert _set_tokenizer version

* add driver async generate

* add async test

* fix bugs in test_ray_dist.py

* add get_tokenizer.py

* fix code style

* fix bugs about No module named 'pydantic' in ci test

* fix bugs in ci test

* fix bugs in ci test

* fix bugs in ci test

* [infer]Add Ray Distributed Environment Init Scripts (#4911)

* Revert "[inference] Async dynamic batching  (#4894)"

This reverts commit fced140250.

* Add Ray Distributed Environment Init Scripts

* support DynamicBatchManager base function

* revert _set_tokenizer version

* add driver async generate

* add async test

* fix bugs in test_ray_dist.py

* add get_tokenizer.py

* fix code style

* fix bugs about No module named 'pydantic' in ci test

* fix bugs in ci test

* fix bugs in ci test

* fix bugs in ci test

* support dynamic batch for bloom model and is_running function

* [Inference]Test for new Async engine (#4935)

* infer engine

* infer engine

* test engine

* test engine

* new manager

* change step

* add

* test

* fix

* fix

* finish test

* finish test

* finish test

* finish test

* add license

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>

* add assertion for config (#4947)

* [Inference] Finish dynamic batching offline test (#4948)

* test

* fix test

* fix quant

* add default

* fix

* fix some bugs

* fix some bugs

* fix

* fix bug

* fix bugs

* reset param

---------

Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>

* [Kernels] Updated Triton kernels to 2.1.0 and added flash-decoding for llama token attention (#4965)

* adding flash-decoding

* clean

* adding kernel

* adding flash-decoding

* add integration

* add

* adding kernel

* adding kernel

* adding triton 2.1.0 features for inference

* update bloom triton kernel

* remove useless vllm kernels

* clean codes

* fix

* adding files

* fix readme

* update llama flash-decoding

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>

* fix ColossalEval (#4992)

Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>

* [doc]Update doc for colossal-inference (#4989)

* update doc

* Update README.md

---------

Co-authored-by: cuiqing.li <lixx336@gmail.com>

* [hotfix] Fix the bug where process groups were not being properly released. (#4940)

* Fix the bug where process groups were not being properly released.

* test

* Revert "test"

This reverts commit 479900c139.

* [hotfix] fix the bug of repeatedly storing param group (#4951)

* [doc] add supported feature diagram for hybrid parallel plugin (#4996)

* [Pipeline Inference] Merge pp with tp (#4993)

* refactor pipeline into new CaiInferEngine

* update llama modeling forward

* merge tp with pp

* update docstring

* optimize test workflow and example

* fix typo

* add assert and todo

* [release] update version (#4995)

* [release] update version

* [hotfix] fix ci

* [gemini] gemini support tp

* fix

* update checkpointIO

* support fused layernorm

* update fusedlayernorm

* add sequence parallel to gemini

* fix

* fix comments

* fix

* fix t5

* clear cache

* fix

* activate ci

* activate ci

* fix

* fix

* fix

* fix

* revert

* modify tp gather method

* fix test

---------

Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Zian(Andy) Zheng <62330719+Orion-Zheng@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: Cuiqing Li <lixx3527@gmail.com>
Co-authored-by: cuiqing.li <lixx336@gmail.com>
Co-authored-by: CjhHa1 <cjh18671720497@outlook.com>
Co-authored-by: Xu Kai <xukai16@foxmail.com>
Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: yuehuayingxueluo <867460659@qq.com>
Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com>
Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com>
Co-authored-by: littsk <1214689160@qq.com>
Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com>
flybird11111 2023-11-10 10:15:16 +08:00 committed by GitHub
parent a4489384d5
commit 576a2f7b10
13 changed files with 390 additions and 67 deletions

View File

@@ -5,6 +5,7 @@ from pathlib import Path
 from typing import Callable, Iterator, List, Optional, Tuple
 
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
@@ -19,8 +20,9 @@ from colossalai.checkpoint_io.utils import (
     save_state_dict,
     save_state_dict_shards,
 )
-from colossalai.cluster import DistCoordinator
+from colossalai.cluster import DistCoordinator, ProcessGroupMesh
 from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.utils import get_current_device
 from colossalai.zero import GeminiDDP, GeminiOptimizer
 from colossalai.zero.gemini.memory_tracer import MemStats
@@ -32,7 +34,25 @@ __all__ = ["GeminiPlugin"]
 SUPPORTED_PRECISION = ["fp16", "bf16"]
 PRECISION_STR_TO_DTYPE = {"fp16": torch.half, "bf16": torch.bfloat16}
 
+DP_AXIS = 0
+TP_AXIS = 1
+
+
+def get_param_info(optim: Optimizer):
+    # Get a backup of necessary information of parameters for future use, which includes:
+    # 1. A mapping from integer param_id to param32 shape.
+    if optim is None:
+        return {}
+    param_info = {"id2shape": {}}
+    start_index = 0
+    for group in optim.param_groups:
+        for param_id, param in enumerate(group["params"], start_index):
+            original_shape = param.shape if isinstance(param, torch.Tensor) else None
+            param_info["id2shape"][param_id] = original_shape
+        start_index += len(group["params"])
+    return param_info
+
+
 class GeminiCheckpointIO(GeneralCheckpointIO):
     def __init__(self) -> None:
         super().__init__()
@@ -284,6 +304,16 @@ class GeminiPlugin(DPPluginBase):
         max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
             clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
         norm_type (float, optional): norm_type used for `clip_grad_norm`.
+        enable_tensor_parallelism (bool, optional): Whether to use tensor parallelism strategy, which is implemented in Shardformer. Default to False.
+        tp_size (int, optional): If 'enable_tensor_parallelism' is set to true, please configure 'tp_size' which determines the size of the tensor parallel process group. Default to 1.
+        enable_all_optimization (bool, optional): Whether to switch on all the optimizations supported by Shardformer.
+            Currently all the optimization methods include fused normalization, flash attention and JIT.
+            Defaults to False.
+        enable_fused_normalization (bool, optional): Whether to switch on fused normalization in Shardformer. Defaults to False.
+        enable_flash_attention (bool, optional): Whether to switch on flash attention in Shardformer. Defaults to False.
+        enable_jit_fused (bool, optional): Whether to switch on JIT in Shardformer. Default to False.
+        enable_sequence_parallelism (bool): Whether to turn on sequence parallelism in Shardformer. Defaults to False.
+        enable_sequence_overlap (bool): Whether to turn on sequence overlap in Shardformer. Defaults to False.
         verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
     """
@@ -317,6 +347,14 @@
         max_scale: float = 2**32,
         max_norm: float = 0.0,
         norm_type: float = 2.0,
+        enable_tensor_parallelism: bool = False,
+        tp_size: int = 1,
+        enable_all_optimization: bool = False,
+        enable_fused_normalization: bool = False,
+        enable_flash_attention: bool = False,
+        enable_sequence_parallelism: bool = False,
+        enable_jit_fused: bool = False,
+        enable_sequence_overlap: bool = False,
         verbose: bool = False,
     ) -> None:
         super().__init__()
@@ -355,8 +393,32 @@
             max_norm=max_norm,
             norm_type=norm_type,
         )
+        self.enable_tensor_parallelism = enable_tensor_parallelism
+        self.enable_all_optimization = enable_all_optimization
+        self.enable_fused_normalization = enable_fused_normalization
+        self.enable_flash_attention = enable_flash_attention
+        self.enable_sequence_parallelism = enable_sequence_parallelism if self.enable_tensor_parallelism else False
+        self.enable_jit_fused = enable_jit_fused
+        self.enable_sequence_overlap = enable_sequence_overlap
         self.verbose = verbose
+
+        self.tp_size = tp_size if self.enable_tensor_parallelism else 1
+        self.dp_size = dist.get_world_size() // self.tp_size
+        assert self.dp_size > 1, f"The size of the DP group should be greater than 1. Please reduce the TP group size."
+        self.pg_mesh = ProcessGroupMesh(self.dp_size, self.tp_size)
+        self.dp_group = self.pg_mesh.get_group_along_axis(DP_AXIS)
+        self.tp_group = self.pg_mesh.get_group_along_axis(TP_AXIS)
+
+        self.shard_config = ShardConfig(
+            tensor_parallel_process_group=self.tp_group,
+            enable_tensor_parallelism=self.enable_tensor_parallelism,
+            enable_all_optimization=self.enable_all_optimization,
+            enable_fused_normalization=self.enable_fused_normalization,
+            enable_flash_attention=self.enable_flash_attention,
+            enable_jit_fused=self.enable_jit_fused,
+            enable_sequence_parallelism=self.enable_sequence_parallelism,
+            enable_sequence_overlap=self.enable_sequence_overlap,
+        )
 
     def support_no_sync(self) -> bool:
         return False
@@ -380,6 +442,7 @@
         dataloader: Optional[DataLoader] = None,
         lr_scheduler: Optional[LRScheduler] = None,
     ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
+        optimizer_params_info = get_param_info(optimizer)
         if not isinstance(model, ModelWrapper):
             # convert model to sync bn
             # FIXME(ver217): gemini does not support sync bn
@@ -391,11 +454,21 @@
             # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)
 
             # wrap the model with Gemini
-            model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)
+            if self.enable_tensor_parallelism:
+                shardformer = ShardFormer(self.shard_config)
+                model, _ = shardformer.optimize(model)
+
+            model = GeminiDDP(model, **self.gemini_config, process_group=self.dp_group, verbose=self.verbose)
 
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             optimizer = GeminiOptimizer(
-                optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
+                optimizer,
+                model,
+                **self.zero_optim_config,
+                **self.optim_kwargs,
+                tp_group=self.tp_group,
+                optimizer_params_info=optimizer_params_info,
+                verbose=self.verbose,
             )
 
         return model, optimizer, criterion, dataloader, lr_scheduler
@@ -407,4 +480,4 @@
         return GeminiCheckpointIO()
 
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError
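For context while reading the hunks above, a minimal usage sketch of the new plugin flags (illustrative only, mirroring the updated tests further below; the toy model, optimizer, and launch call are assumptions, and tp_size=2 presumes at least 4 ranks so the DP group stays larger than 1, as the new assertion requires):

    import torch.nn as nn
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin
    from colossalai.nn.optimizer import HybridAdam

    colossalai.launch_from_torch(config={})

    plugin = GeminiPlugin(
        precision="fp16",
        initial_scale=2**14,
        enable_tensor_parallelism=True,  # new flag: shard the model with Shardformer first
        tp_size=2,                       # size of the tensor-parallel process group
        enable_all_optimization=True,    # fused normalization / flash attention / JIT
    )
    booster = Booster(plugin=plugin)

    # hypothetical toy model, only to make the sketch self-contained
    model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    model, optimizer, *_ = booster.boost(model, optimizer)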

View File

@@ -225,3 +225,4 @@ class ProcessGroupMesh:
             # no need to cache it explicitly, since it will be cached in `create_group_along_axis`
             return self.create_group_along_axis(axis, indices_at_axis, backend=backend)
         return self._ranks_to_group[ranks_in_group]

View File

@@ -53,7 +53,7 @@ class MatmulWithAsyncCommunication(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
-        ctx.save_for_backward(input_, weight)
+        ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_allreduce = async_grad_allreduce
@@ -62,13 +62,18 @@
         if bias is not None:
            output = output + bias
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
-        input, weight = ctx.saved_tensors
+        input, weight, bias = ctx.saved_tensors
         use_bias = ctx.use_bias
 
+        # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias.
+        weight = weight.view(weight.shape)
+        bias = bias.view(bias.shape)
+
         total_input = input
         grad_input = grad_output.matmul(weight.T)
         grad_output = grad_output.contiguous()
@@ -100,7 +105,7 @@ class LinearWithAsyncCommunication(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, input_, weight, bias, process_group, async_grad_allreduce):
-        ctx.save_for_backward(input_, weight)
+        ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_allreduce = async_grad_allreduce
@@ -109,13 +114,18 @@
             output = F.linear(input_, weight, bias)
         else:
             output = F.linear(input_, weight)
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
-        input, weight = ctx.saved_tensors
+        input, weight, bias = ctx.saved_tensors
         use_bias = ctx.use_bias
 
+        # In order to be hooked into Gemini's '__torch_function__', adding a view operation to bias.
+        if use_bias:
+            bias.view(bias.shape)
+
         total_input = input
         grad_input = grad_output.matmul(weight)
         grad_output = grad_output.contiguous()
@@ -152,7 +162,7 @@ class _LinearWithGatherForwardReduceScatterBackward(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap=True):
-        ctx.save_for_backward(input_, weight)
+        ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
@@ -170,12 +180,16 @@
 
     @staticmethod
     def backward(ctx, grad_output):
-        input_, weight = ctx.saved_tensors
+        input_, weight, bias = ctx.saved_tensors
         use_bias = ctx.use_bias
         dim = ctx.dim
         process_group = ctx.process_group
         overlap = ctx.overlap
 
+        # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
+        if use_bias:
+            bias = bias.view(bias.shape)
+
         if not overlap:
             input_parallel = _gather(input_, dim, process_group)
@@ -289,7 +303,7 @@ class _MatmulWithGatherForwardReduceScatterBackward(torch.autograd.Function):
 
     @staticmethod
     def forward(ctx, input_, weight, bias, process_group, async_grad_reduce_scatter, dim, overlap):
-        ctx.save_for_backward(input_, weight)
+        ctx.save_for_backward(input_, weight, bias)
         ctx.use_bias = bias is not None
         ctx.process_group = process_group
         ctx.async_grad_reduce_scatter = async_grad_reduce_scatter
@@ -306,12 +320,17 @@
 
     @staticmethod
     def backward(ctx, grad_output):
-        input_, weight = ctx.saved_tensors
+        input_, weight, bias = ctx.saved_tensors
         use_bias = ctx.use_bias
         dim = ctx.dim
         process_group = ctx.process_group
         overlap = ctx.overlap
 
+        # In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm
+        weight = weight.view(weight.shape)
+        if use_bias:
+            bias = bias.view(bias.shape)
+
         if not overlap:
             input_parallel = _gather(input_, dim, process_group)
@@ -454,6 +473,29 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
 
     @staticmethod
     def backward(ctx, grad_output):
         return _split(grad_output, ctx.dim, ctx.process_group), None, None
 
 
+class HookParameter(torch.autograd.Function):
+    """In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm"""
+
+    @staticmethod
+    def forward(ctx, input, weight, bias):
+        ctx.save_for_backward(weight, bias)
+        output = input
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        weight, bias = ctx.saved_tensors
+        if weight is not None:
+            weight = weight.view(weight.shape)
+        if bias is not None:
+            bias = bias.view(bias.shape)
+        return grad_output, None, None
+
+
+def hook_paramter_in_backward(input, weight=None, bias=None):
+    return HookParameter.apply(input, weight, bias)
+
+
 def _reduce(input_, process_group):
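A standalone sketch (not part of the diff) of the mechanism the comments above rely on: any op on a tensor whose class defines __torch_function__ is observed by that class, so a no-op view(x.shape) in backward is enough to make Gemini "see" a weight or bias that would otherwise never appear in any op. The TracingTensor subclass below is purely illustrative:

    import torch

    class TracingTensor(torch.Tensor):
        seen = []

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            # record every op dispatched on this subclass, then defer to the default handler
            cls.seen.append(getattr(func, "__name__", str(func)))
            return super().__torch_function__(func, types, args, kwargs or {})

    w = torch.randn(4, 4).as_subclass(TracingTensor)
    _ = w.view(w.shape)        # a no-op reshape, but the subclass still observes it
    print(TracingTensor.seen)  # e.g. ['view'] (attribute accesses may also appear on some versions)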

View File

@@ -309,7 +309,8 @@ class VocabParallelEmbedding1D(ParallelModule):
         )
 
         # Mask the output embedding.
-        output_parallel[input_mask, :] = 0.0
+        embedding_output = output_parallel.clone()
+        embedding_output[input_mask, :] = 0.0
 
         # Reduce across all the model parallel GPUs.
-        output = reduce_forward(output_parallel, self.process_group)
+        output = reduce_forward(embedding_output, self.process_group)
         return output
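A standalone, single-process sketch (not the ColossalAI layer) of the masking pattern in this hunk: token ids outside the local vocabulary shard are looked up against a clamped index, and their rows are zeroed on a clone of the raw lookup before the cross-rank reduction, so only the owning rank contributes a non-zero embedding. The shard bounds and sizes below are made up:

    import torch
    import torch.nn.functional as F

    vocab_start, vocab_end = 4, 8            # this rank owns ids [4, 8)
    weight = torch.randn(4, 6)               # local shard of the embedding table
    input_ids = torch.tensor([2, 5, 7, 9])

    input_mask = (input_ids < vocab_start) | (input_ids >= vocab_end)
    local_ids = input_ids.clone() - vocab_start
    local_ids[input_mask] = 0                # clamp out-of-range ids to a valid local row

    output_parallel = F.embedding(local_ids, weight)
    embedding_output = output_parallel.clone()
    embedding_output[input_mask, :] = 0.0    # zero the rows this rank does not own
    # in the real layer, this result is then reduced (all-reduce) across the TP group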

View File

@@ -1,15 +1,29 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
+import warnings
 from abc import ABC, abstractmethod
 
 import torch.nn as nn
 
 from colossalai.lazy import LazyInitContext
 
+from ._operation import hook_paramter_in_backward
 from .utils import SeqParallelUtils
 
 __all__ = ["FusedLayerNorm", "FusedRMSNorm", "LayerNorm", "RMSNorm", "BaseLayerNorm"]
 
+try:
+    from apex.contrib.layer_norm.layer_norm import FastLayerNorm
+
+    EnableFastLayerNorm = True
+except ImportError:
+    EnableFastLayerNorm = False
+
+try:
+    from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
+    from apex.normalization import FusedRMSNorm as ApexFusedRMSNorm
+except ImportError:
+    warnings.warn(
+        "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel"
+    )
+
 FAST_LAYERNORM_SUPPORTED_SIZE = [
     1024,
     1536,
@@ -37,6 +51,34 @@
     65536,
 ]
 
+if EnableFastLayerNorm:
+
+    class FastLayerNormWithHook(FastLayerNorm):
+        def __init__(self, hidden_size, eps=0.00001):
+            super().__init__(hidden_size, eps)
+
+        def forward(self, input):
+            output = super().forward(input)
+            output = hook_paramter_in_backward(output, self.weight, self.bias)
+            return output
+
+
+class FusedLayerNormWithHook(ApexFusedLayerNorm):
+    def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
+        super().__init__(normalized_shape, eps, elementwise_affine)
+
+    def forward(self, input):
+        output = super().forward(input)
+        output = hook_paramter_in_backward(output, self.weight, self.bias)
+        return output
+
+
+class FusedRMSNormWithHook(ApexFusedRMSNorm):
+    def __init__(self, normalized_shape, eps=0.00001, elementwise_affine=True):
+        super().__init__(normalized_shape, eps, elementwise_affine)
+
+    def forward(self, input):
+        output = super().forward(input)
+        output = hook_paramter_in_backward(output, self.weight)
+        return output
+
+
 class BaseLayerNorm(ABC):
     @abstractmethod
@@ -161,16 +203,6 @@
         Raises:
             AssertionError: If the provided module is not an instance of nn.LayerNorm.
         """
-        # check if apex is installed
-        assert isinstance(module, nn.LayerNorm), "Only support conversion from nn.LayerNorm."
-
-        try:
-            pass
-        except ImportError:
-            raise ImportError(
-                "Please install apex from source (https://github.com/NVIDIA/apex) to use the fused layernorm kernel"
-            )
-
         LazyInitContext.materialize(module)
 
         # get the attributes of the module
@@ -184,18 +216,17 @@
         use_fast_ln = normalized_shape in FAST_LAYERNORM_SUPPORTED_SIZE
 
         if use_fast_ln:
-            try:
-                from apex.contrib.layer_norm.layer_norm import FastLayerNorm as ApexFusedLayerNorm
-            except ImportError:
+            if EnableFastLayerNorm:
+                ApexFusedLayerNorm = FastLayerNormWithHook
+            else:
                 # fall back to the normal fused layernorm is not built
-                from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
+                ApexFusedLayerNorm = FusedLayerNormWithHook
         else:
-            from apex.normalization import FusedLayerNorm as ApexFusedLayerNorm
+            ApexFusedLayerNorm = FusedLayerNormWithHook
 
         layernorm = (
             ApexFusedLayerNorm(normalized_shape, eps=eps, elementwise_affine=elementwise_affine).to(dtype).to(device)
         )
         layernorm.weight = module.weight
         layernorm.bias = module.bias
@@ -213,13 +244,12 @@
     """
     This is a wrapper around the apex fused rms norm implementation. It is meant to be used only with the from_native_module interface.
     """
 
     def __init__(self) -> None:
         raise NotImplementedError(
             "FusedRMSNorm is not implemented as a physical class. "
             "It is meant to be used only with the from_native_module interface to Convert a native RMSNorm module to FusedRMSNorm module provided by apex."
         )
 
     @staticmethod
     def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *args, **kwargs) -> nn.Module:
         r"""
@@ -252,7 +282,7 @@
         eps = module.eps
         elementwise_affine = module.elementwise_affine
 
-        rmsnorm = ApexFusedRMSNorm(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
+        rmsnorm = FusedRMSNormWithHook(normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine)
 
         rmsnorm.weight = module.weight

View File

@@ -719,7 +719,7 @@ def get_bloom_flash_attention_forward(enabel_jit_fused=False):
     ):
         fused_qkv = self.query_key_value(hidden_states)
         (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
-        batch_size, tgt_len, _ = query_layer.size()
+        batch_size, tgt_len, _, _ = query_layer.size()
 
         _, kv_length, _, _ = key_layer.size()
 
@@ -755,6 +755,7 @@
         attention_numerical_mask = torch.masked_fill(
             attention_numerical_mask, attention_mask, torch.finfo(torch.float32).min
         )
+        attention_numerical_mask = attention_numerical_mask.to(query_layer.dtype)
 
         context_layer = me_attention(
             query_layer,
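A standalone illustration (not the ColossalAI kernel) of what the added cast is for: an additive attention mask built in fp32 has to be moved to the query dtype before it is combined with fp16 scores or handed to a fused attention kernel, since such kernels generally require matching dtypes. Shapes and values below are arbitrary:

    import torch

    attention_mask = torch.tensor([[False, True]])               # True = masked position
    numerical_mask = torch.zeros_like(attention_mask, dtype=torch.float32)
    numerical_mask = torch.masked_fill(numerical_mask, attention_mask,
                                       torch.finfo(torch.float32).min)

    query = torch.randn(1, 2, 4, dtype=torch.float16)
    numerical_mask = numerical_mask.to(query.dtype)               # the cast added above
    scores = torch.randn(1, 2, 2, dtype=torch.float16) + numerical_mask
    print(scores.dtype)                                           # torch.float16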

View File

@@ -183,14 +183,6 @@ class T5BasePolicy(Policy):
             policy=policy,
             target_key=T5LayerFF,
         )
-        self.append_or_create_submodule_replacement(
-            description=SubModuleReplacementDescription(
-                suffix="layer_norm",
-                target_module=norm_cls,
-            ),
-            policy=policy,
-            target_key=T5LayerFF,
-        )
         self.append_or_create_submodule_replacement(
             description=SubModuleReplacementDescription(suffix="layer_norm", target_module=norm_cls),
             policy=policy,

View File

@@ -2,7 +2,9 @@ from .api import (
     compute_global_numel,
     customized_distributed_tensor_to_param,
     distribute_tensor,
+    init_as_dtensor,
     distribute_tensor_with_customization,
+    init_tensor_as_customization_distributed,
     get_device_mesh,
     get_global_shape,
     get_layout,
@@ -23,6 +25,7 @@ from .sharding_spec import ShardingSpec
 __all__ = [
     "is_distributed_tensor",
     "distribute_tensor",
+    "init_as_dtensor",
     "to_global",
     "is_sharded",
     "shard_rowwise",
@@ -36,6 +39,7 @@
     "get_layout",
     "is_customized_distributed_tensor",
     "distribute_tensor_with_customization",
+    "init_tensor_as_customization_distributed",
     "to_global_for_customized_distributed_tensor",
     "customized_distributed_tensor_to_param",
     "Layout",

View File

@@ -128,6 +128,17 @@ def distribute_tensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_sp
     return sharded_tensor
 
 
+def init_as_dtensor(tensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec, global_shape: torch.Size) -> torch.Tensor:
+    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
+    dist_layout = Layout(device_mesh=device_mesh, sharding_spec=sharding_spec, global_shape=global_shape)
+
+    # shard tensor
+    tensor.dist_layout = dist_layout
+
+    # hack some tensor methods
+    _hijack_detach_and_clone(tensor)
+
+    return tensor
+
+
 def redistribute(dtensor: torch.Tensor, device_mesh: DeviceMesh, sharding_spec: ShardingSpec) -> None:
     """
@@ -420,6 +431,54 @@
     return sharded_tensor
 
 
+def init_tensor_as_customization_distributed(tensor: torch.Tensor, shard_fn, gather_fn: callable):
+    """
+    Distribute the given tensor with the given shard_fn and gather_fn.
+
+    Example:
+
+    ```python
+    # define shard and gather functions
+    def shard_fn(tensor):
+        rank = torch.distributed.get_rank()
+        world_size = torch.distributed.get_world_size()
+        return tensor.chunk(world_size, dim=0)[rank]
+
+    def gather_fn(tensor):
+        rank = torch.distributed.get_rank()
+        world_size = torch.distributed.get_world_size()
+        shard_list = [torch.zeros_like(tensor) for _ in range(world_size)]
+        torch.distributed.all_gather(shard_list, tensor)
+        return torch.cat(shard_list, dim=0)
+
+    # create a distributed tensor
+    tensor = torch.rand(4, 4)
+    dtensor = init_tensor_as_customization_distributed(tensor, shard_fn, gather_fn)
+    ```
+
+    Args:
+        tensor (torch.Tensor): The tensor to be distributed.
+        shard_fn (callable): The function to shard the tensor.
+        gather_fn (callable): The function to gather the tensor.
+
+    Returns:
+        torch.Tensor: The distributed tensor.
+    """
+    assert callable(shard_fn), "The shard_fn must be callable."
+    assert callable(gather_fn), "The gather_fn must be callable."
+    assert not is_distributed_tensor(tensor), "The input tensor is already a distributed tensor."
+
+    # set the shard_fn and gather_fn as attributes of the distributed tensor
+    tensor.shard_fn = shard_fn
+    tensor.gather_fn = gather_fn
+
+    # set the shard_fn and gather_fn as attributes of the distributed tensor
+    _hijack_detach_and_clone_for_customized_distributed_tensor(tensor)
+
+    return tensor
+
+
 def to_global_for_customized_distributed_tensor(dtensor: torch.Tensor) -> torch.Tensor:
     """
     Gather the given tensor to the global tensor.

View File

@@ -17,6 +17,7 @@ from colossalai.logging import get_dist_logger
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import _cast_float, free_storage, get_current_device, is_ddp_ignored
+from colossalai.checkpoint_io.utils import gather_distributed_param
 
 from .chunk import Chunk, ChunkManager, TensorState, init_chunk_manager
 from .gemini_hook import GeminiZeROHook
@@ -24,6 +25,18 @@ from .gemini_mgr import GeminiManager
 from .memory_tracer import MemStats, OrderedParamGenerator
 from .utils import get_temp_total_chunk_on_cuda
 
+from colossalai.tensor.d_tensor import (
+    distribute_tensor,
+    distribute_tensor_with_customization,
+    init_tensor_as_customization_distributed,
+    get_device_mesh,
+    get_sharding_spec,
+    is_customized_distributed_tensor,
+    is_distributed_tensor,
+    get_global_shape,
+    init_as_dtensor
+)
+
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX, _IncompatibleKeys
 except ImportError:
@@ -318,9 +331,7 @@
         self._post_backward()
 
     def backward_by_grad(self, tensor, grad):
-        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
-            torch.autograd.backward(tensor, grad)
-        self._post_backward()
+        raise RuntimeError("Gemini is not compatible with pipeline. backward_by_grad shoudn't be called in Gemini.")
 
     def grad_handle(self, p, grad):
         setattr(p, "_gemini_reduced", True)
@@ -431,7 +442,18 @@
             record_tensor = torch.empty([0])
             record_flag = (not only_rank_0) | (dist.get_rank(chunk.torch_pg) == 0)
             if record_flag:
-                record_tensor = temp_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape).cpu()
+                record_tensor = temp_chunk[tensor_info.offset : tensor_info.end].view(tensor.shape).to(tensor.device)
+                if is_distributed_tensor(tensor):
+                    global_shape = get_global_shape(tensor)
+                    device_mesh = get_device_mesh(tensor)
+                    shard_spec = get_sharding_spec(tensor)
+                    record_tensor = init_as_dtensor(record_tensor,
+                                                    device_mesh=device_mesh,
+                                                    sharding_spec=shard_spec,
+                                                    global_shape = global_shape)
+                elif is_customized_distributed_tensor(tensor):
+                    init_tensor_as_customization_distributed(record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn)
+                record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu()
 
             assert tensor not in chunk_to_save_data
             chunk_to_save_data[tensor] = record_tensor
@@ -606,10 +628,16 @@
         local_name_params = itertools.chain(self.named_parameters(), persistent_buffers.items())
         local_state = {k: v for k, v in local_name_params if v is not None}
 
-        def load(param_name, dest_tensor, copy_func):
+        def load(param_name, dest_tensor, copy_func, source_device_mesh=None, source_sharding_spec=None, shard_fn=None, gather_fn=None):
             state_key = prefix + param_name
             if state_key in state_dict:
                 input_param = state_dict[state_key]
+
+                if source_device_mesh is not None and source_sharding_spec is not None:
+                    input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
+                elif shard_fn is not None and gather_fn is not None:
+                    input_param = distribute_tensor_with_customization(input_param, shard_fn=shard_fn, gather_fn=gather_fn)
+
                 # Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
                 if len(dest_tensor.shape) == 0 and len(input_param.shape) == 1:
                     input_param = input_param[0]
@@ -653,9 +681,19 @@
             temp_chunk = get_temp_total_chunk_on_cuda(chunk, self.mixed_precision)
 
             for tensor, tensor_info in chunk.tensors_info.items():
+                source_device_mesh, source_sharding_spec, shard_fn, gather_fn = None, None, None, None
+                if is_distributed_tensor(tensor):
+                    # shard the input param
+                    source_device_mesh = get_device_mesh(tensor)
+                    source_sharding_spec = get_sharding_spec(tensor)
+                elif is_customized_distributed_tensor(tensor):
+                    shard_fn = tensor.shard_fn
+                    gather_fn = tensor.gather_fn
+
                 parameter_name = fp32_to_name[tensor] if self.reuse_fp16_chunk else self.param2name[tensor]
                 parameter_slice = temp_chunk[tensor_info.offset : tensor_info.end]
-                load(parameter_name, tensor, partial(load_parameter, parameter_slice))
+                load(parameter_name, tensor, partial(load_parameter, parameter_slice), source_device_mesh, source_sharding_spec, shard_fn, gather_fn)
 
             if chunk.is_gathered:
                 chunk.cuda_global_chunk.copy_(temp_chunk)
@@ -724,7 +762,8 @@
             if self.master_weights:
                 # create a fp32 parameter
-                fp32_p = p.data.float()
+                fp32_p = p.clone()
+                fp32_p.data = fp32_p.data.float()
                 self.chunk_manager.register_tensor(
                     tensor=fp32_p,
                     group_type="fp32_param",

View File

@@ -9,6 +9,7 @@ import torch.distributed as dist
 from packaging.version import Version
 from torch.nn import Parameter
 from torch.optim import Optimizer
+from torch.distributed import ProcessGroup
 
 from colossalai.amp.naive_amp.mixed_precision_mixin import BF16MixedPrecisionMixin, FP16MixedPrecisionMixin
 from colossalai.checkpoint_io.utils import StateDictSharder
@@ -19,6 +20,18 @@ from colossalai.utils import disposable, get_current_device, is_ddp_ignored
 from .chunk import Chunk, ChunkManager
 from .gemini_ddp import GeminiDDP
 
+from colossalai.checkpoint_io.utils import gather_distributed_param
+from colossalai.tensor.d_tensor import (
+    distribute_tensor,
+    distribute_tensor_with_customization,
+    init_tensor_as_customization_distributed,
+    get_device_mesh,
+    get_sharding_spec,
+    is_customized_distributed_tensor,
+    is_distributed_tensor,
+    get_global_shape,
+    init_as_dtensor
+)
+
 __all__ = ["GeminiOptimizer", "GeminiAdamOptimizer"]
@@ -93,6 +106,8 @@
         max_scale: float = 2**32,
         max_norm: float = 0.0,
         norm_type: float = 2.0,
+        tp_group: ProcessGroup = None,
+        optimizer_params_info=None,
         verbose: bool = False,
         **defaults: Any,
     ):
@@ -109,6 +124,10 @@
         self.chunk16_set: Set[Chunk] = set()
         self.clipping_flag = max_norm > 0.0
         self.max_norm = max_norm
+        self.tp_group = tp_group
+        self.optimizer_params_info = optimizer_params_info
+        self.tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
+        self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
         self.verbose = verbose
         self.param_groups_backup = list()
@@ -406,8 +425,8 @@
         param = self.id_to_real_params[param_id]
         fake_param = self.id_to_fake_params.get(param_id, None)
         chunk = self.chunk_manager.get_chunk(param)
-        process_group = chunk.torch_pg
-        rank = dist.get_rank(process_group)
+        dp_group = chunk.torch_pg
+        rank = dist.get_rank(dp_group)
         master_rank = 0
         collected_states = {}
@@ -415,9 +434,9 @@
         local_state_names = None
         if fake_param is not None:
             local_state_names = list(self.optim.state[fake_param].keys())
-        gathered_state_names = [None for _ in range(dist.get_world_size(process_group))]
+        gathered_state_names = [None for _ in range(dist.get_world_size(dp_group))]
         dist.barrier()
-        dist.all_gather_object(gathered_state_names, local_state_names)
+        dist.all_gather_object(gathered_state_names, local_state_names, dp_group)
         state_names = None
         for names in gathered_state_names:
             if names is not None:
@@ -436,6 +455,13 @@
         # Every rank is collector when only_rank_0 is False.
         is_collector = (rank == master_rank) or (not only_rank_0)
 
+        # get tensor parallelism information
+        is_dtensor = is_distributed_tensor(param)
+        is_customized_distributed = is_customized_distributed_tensor(param)
+        shard_spec = get_sharding_spec(param) if is_dtensor else None
+        device_mesh = get_device_mesh(param) if is_dtensor else None
+        global_shape = self.optimizer_params_info["id2shape"][param_id]
+
         # If the chunk is kept gathered,
         # the parameteres are treated the same as that of those in strict DDP during training.
         # So states can be directly fetched from current device.
@@ -451,7 +477,18 @@
                     ).cpu()
                 else:
                     state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
-                collected_states[state_name] = torch.reshape(state_tensor, param.shape)
+                if is_dtensor:
+                    state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
+                    state_tensor = init_as_dtensor(state_tensor,
+                                                   device_mesh=device_mesh,
+                                                   sharding_spec=shard_spec,
+                                                   global_shape = global_shape)
+                elif is_customized_distributed:
+                    state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
+                    init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn)
+                state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
+                collected_states[state_name] = state_tensor.reshape(global_shape)
             return collected_states
 
         # Check whether the param with given id is managed by current process.
@@ -473,7 +510,7 @@
         _, shard_offset, shard_size = self.get_offsets(param_id)
 
         # Collectors gather state shards through all_gathering.
-        gathered_state_shards = [None for _ in range(dist.get_world_size(process_group))]
+        gathered_state_shards = [None for _ in range(dist.get_world_size(dp_group))]
 
         dist.barrier()
         dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size])
@@ -494,6 +531,16 @@
             for state_name, state_tensor in collected_states.items():
                 if state_tensor.numel() == param.numel():
                     collected_states[state_name] = torch.reshape(state_tensor, param.shape)
+                if is_dtensor:
+                    state_tensor = state_tensor.to(param.device)
+                    state_tensor = init_as_dtensor(state_tensor,
+                                                   sharding_spec=shard_spec,
+                                                   device_mesh=device_mesh,
+                                                   global_shape=global_shape)
+                elif is_customized_distributed:
+                    state_tensor = state_tensor.to(param.device)
+                    init_tensor_as_customization_distributed(state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn)
+                state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
 
         return collected_states
@@ -658,6 +705,14 @@
             ret_val = torch.zeros(
                 state_end - state_start, dtype=torch.float32, device=param.device, requires_grad=False
             )
+
+            if is_dtensor:
+                value = torch.reshape(value, global_shape)
+                value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh)
+            elif is_customized_distributed:
+                value = torch.reshape(value, global_shape)
+                value = distribute_tensor_with_customization(value, real_param.shard_fn, real_param.gather_fn)
+
             ret_val.copy_(value.flatten()[state_start:state_end])
             return ret_val
@@ -668,6 +723,15 @@
         # Copy states assigned to param (and cast tensors to appropriate types).
         updated_states = dict()
 
+        # get tensor parallelism information
+        real_param = self.id_to_real_params[param_id]
+        is_dtensor = is_distributed_tensor(real_param)
+        is_customized_distributed = is_customized_distributed_tensor(real_param)
+        shard_spec = get_sharding_spec(real_param) if is_dtensor else None
+        device_mesh = get_device_mesh(real_param) if is_dtensor else None
+        global_shape = self.optimizer_params_info["id2shape"][param_id]
+
         for k, v in saved_states.items():
             updated_states[k] = cast(fake_param, state_range, v, k)
             del v  # clean loaded states

View File

@@ -10,18 +10,21 @@ from colossalai.booster.plugin import GeminiPlugin
 from colossalai.fx import is_compatible_with_meta
 from colossalai.lazy.lazy_init import LazyInitContext
 from colossalai.nn.optimizer import HybridAdam
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.shardformer.layer.utils import Randomizer
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 from tests.kit.model_zoo import model_zoo
 
 
-def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[str]:
+def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism) -> Optional[str]:
     try:
         if init_method == "lazy":
             ctx = LazyInitContext()
         else:
             ctx = nullcontext()
-        plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5)
+        enable_all_optimization = True if enable_tensor_parallelism else False
+        plugin = GeminiPlugin(max_norm=1.0, initial_scale=2**5, enable_tensor_parallelism=enable_tensor_parallelism, enable_all_optimization=enable_all_optimization)
         booster = Booster(plugin=plugin)
         with ctx:
             model = model_fn()
@@ -46,6 +49,8 @@ def run_fn(init_method, model_fn, data_gen_fn, output_transform_fn) -> Optional[
             booster.backward(loss, optimizer)
             optimizer.step()
 
+    except NotImplementedError:
+        print(f"Tensor Parallelism policy for {model.__class__} is not implemented yet\n.")
     except Exception as e:
         # raise e
         return repr(e)
@@ -57,7 +62,8 @@
 @parameterize("subset", ["torchvision", "transformers", "diffusers"])
 @parameterize("init_method", ["none"])
-def check_gemini_plugin(subset: str, init_method: str = "none", early_stop: bool = True):
+@parameterize("enable_tensor_parallelism", [True, False])
+def check_gemini_plugin(subset: str, init_method: str = "none", enable_tensor_parallelism: bool = True, early_stop: bool = True):
     """check gemini plugin over model zoo
 
     Args:
@@ -116,7 +122,12 @@
             "torchvision_efficientnet_v2_s",
         ]:
             continue
-        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn)
+
+        # TODO debug blip2 when using tp, something wrong with shift_logits's shape
+        if "transformers_blip2" in name:
+            enable_tensor_parallelism = False
+
+        err = run_fn(init_method, model_fn, data_gen_fn, output_transform_fn, enable_tensor_parallelism)
         torch.cuda.empty_cache()
         if err is None:
             passed_models.append(name)

View File

@@ -37,17 +37,20 @@ OPTIM_PLACEMENT_CONFIGS = [
 @parameterize("placement_config", MODEL_PLACEMENT_CONFIGS)
 @parameterize("model_name", ["transformers_bert_for_sequence_classification"])
 @parameterize("use_safetensors", [False, True])
-def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool):
+@parameterize("enable_tensor_parallelism", [True, False])
+@parameterize("tp_size", [2])
+def exam_state_dict_with_origin(placement_config, model_name, use_safetensors: bool, enable_tensor_parallelism: bool, tp_size: int):
     from transformers import BertForSequenceClassification
 
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     bert_model = model_fn()
+    enable_all_optimization = True if enable_tensor_parallelism else False
 
     with shared_tempdir() as tempdir:
         pretrained_path = os.path.join(tempdir, "pretrained")
         bert_model.config.save_pretrained(save_directory=pretrained_path)
 
-        plugin = GeminiPlugin(**placement_config)
+        plugin = GeminiPlugin(**placement_config, enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization)
         booster = Booster(plugin=plugin)
         bert_model, _, _, _, _ = booster.boost(bert_model)
         model_size = sum(p.numel() * p.element_size() for p in bert_model.parameters()) / 1024**2
@@ -63,13 +66,16 @@
 @clear_cache_before_run()
 @parameterize("placement_config", OPTIM_PLACEMENT_CONFIGS)
-@parameterize("shard", [False, True])
+@parameterize("shard", [True, False])
 @parameterize("model_name", ["transformers_gpt"])
 @parameterize("size_per_shard", [32])
-def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int):
+@parameterize("enable_tensor_parallelism", [True, False])
+@parameterize("tp_size", [2])
+def exam_state_dict(placement_config, shard: bool, model_name: str, size_per_shard: int, enable_tensor_parallelism: bool, tp_size: int):
     (model_fn, data_gen_fn, output_transform_fn, _, _) = next(iter(model_zoo.get_sub_registry(model_name).values()))
     criterion = lambda x: x.mean()
-    plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14))
+    enable_all_optimization = True if enable_tensor_parallelism else False
+    plugin = GeminiPlugin(**placement_config, precision="fp16", initial_scale=(2**14), enable_tensor_parallelism=enable_tensor_parallelism, tp_size=tp_size, enable_all_optimization=enable_all_optimization)
     booster = Booster(plugin=plugin)
     model = model_fn()
@@ -148,7 +154,7 @@ def run_dist(rank, world_size, port):
 @pytest.mark.dist
-@pytest.mark.parametrize("world_size", [2])
+@pytest.mark.parametrize("world_size", [4])
 @rerun_if_address_is_in_use()
 def test_gemini_ckpIO(world_size):
     spawn(run_dist, world_size)