[shardformer] refactor embedding resize (#5603)

* [branch rebase] rebase main to Feature/resize_embedding (#5554)

* fix

* [release] update version (#5411)

* [hotfix] fix typo s/keywrods/keywords etc. (#5429)

* [devops] fix compatibility (#5444)

* [devops] fix compatibility

* [hotfix] update compatibility test on pr

* [devops] fix compatibility

* [devops] record duration during comp test

* [test] decrease test duration

* fix falcon

* [shardformer] fix gathering output when using tensor parallelism (#5431)

* fix

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* [doc] release Open-Sora 1.0 with model weights (#5468)

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] update open-sora demo (#5479)

* [doc] update open-sora demo

* [doc] update open-sora demo

* [doc] update open-sora demo

* [example] add grok-1 inference (#5485)

* [misc] add submodule

* remove submodule

* [example] support grok-1 tp inference

* [example] add grok-1 inference script

* [example] refactor code

* [example] add grok-1 readme

* [example] add test ci

* [example] update readme

---------

Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* [CI] run pre-commit (#5577)

* fix

* [release] update version (#5411)

* [hotfix] fix typo s/keywrods/keywords etc. (#5429)

* [devops] fix compatibility (#5444)

* [devops] fix compatibility

* [hotfix] update compatibility test on pr

* [devops] fix compatibility

* [devops] record duration during comp test

* [test] decrease test duration

* fix falcon

* [shardformer] fix gathering output when using tensor parallelism (#5431)

* fix

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* [doc] release Open-Sora 1.0 with model weights (#5468)

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] release Open-Sora 1.0 with model weights

* [doc] update open-sora demo (#5479)

* [doc] update open-sora demo

* [doc] update open-sora demo

* [doc] update open-sora demo

* [example] add grok-1 inference (#5485)

* [misc] add submodule

* remove submodule

* [example] support grok-1 tp inference

* [example] add grok-1 inference script

* [example] refactor code

* [example] add grok-1 readme

* [example] add test ci

* [example] update readme

* run pre-commit

---------

Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>

* [rebase] rebase main to resize-embedding (#5581)

* [release] grok-1 314b inference (#5490)

* [release] grok-1 inference

* [release] grok-1 inference

* [release] grok-1 inference

* [example] update Grok-1 inference (#5495)

* revise grok-1 example

* remove unused arg in scripts

* prevent re-installing torch

* update readme

* revert modifying colossalai requirements

* add perf

* trivial

* add tokenizer url

* [hotfix] set return_outputs=False in examples and polish code (#5404)

* fix: simplify merge_batch

* fix: use return_outputs=False to eliminate extra memory consumption

* feat: add return_outputs warning

* style: remove `return_outputs=False` as it is the default value

* [release] grok-1 inference benchmark (#5500)

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [release] grok-1 inference benchmark

* [shardformer]Fix lm parallel. (#5480)

* fix

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* fix lm forward distribution

* fix

* test ci

* fix

* [fix] fix grok-1 example typo (#5506)

* [devops] fix example test ci (#5504)

* Fix ColoTensorSpec for py11 (#5440)

* fixed layout converter caching and updated tester

* Empty-Commit

* [shardformer] update colo attention to support custom mask (#5510)

* [feature] refactor colo attention (#5462)

* [extension] update api

* [feature] add colo attention

* [feature] update sdpa

* [feature] update npu attention

* [feature] update flash-attn

* [test] add flash attn test

* [test] update flash attn test

* [shardformer] update modeling to fit colo attention (#5465)

* [misc] refactor folder structure

* [shardformer] update llama flash-attn

* [shardformer] fix llama policy

* [devops] update tensornvme install

* [test] update llama test

* [shardformer] update colo attn kernel dispatch

* [shardformer] update blip2

* [shardformer] update chatglm

* [shardformer] update gpt2

* [shardformer] update gptj

* [shardformer] update opt

* [shardformer] update vit

* [shardformer] update colo attention mask prep

* [shardformer] update whisper

* [test] fix shardformer tests (#5514)

* [test] fix shardformer tests

* [test] fix shardformer tests

* [format] applied code formatting on changed files in pull request 5510 (#5517)

Co-authored-by: github-actions <github-actions@github.com>

* [shardformer] fix pipeline forward error if custom layer distribution is used (#5189)

* Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution

* Change static methods for t5 layer distribution to member functions

* Change static methods for whisper layer distribution to member functions

* Replace whisper policy usage with self one

* Fix test case to use non-static layer distribution methods

* fix: fix typo

---------

Co-authored-by: Wenhao Chen <cwher@outlook.com>

* [Fix] Grok-1 uses tokenizer from the same pretrained path (#5532)

* [fix] use tokenizer from the same pretrained path

* trust remote code

* [ColossalChat] Update RLHF V2 (#5286)

* Add dpo. Fix sft, ppo, lora. Refactor all

* fix and tested ppo

* 2 nd round refactor

* add ci tests

* fix ci

* fix ci

* fix readme, style

* fix readme style

* fix style, fix benchmark

* reproduce benchmark result, remove useless files

* rename to ColossalChat

* use new image

* fix ci workflow

* fix ci

* use local model/tokenizer for ci tests

* fix ci

* fix ci

* fix ci

* fix ci timeout

* fix rm progress bar. fix ci timeout

* fix ci

* fix ci typo

* remove 3d plugin from ci temporarily

* test environment

* cannot save optimizer

* support chat template

* fix readme

* fix path

* test ci locally

* restore build_or_pr

* fix ci data path

* fix benchmark

* fix ci, move ci tests to 3080, disable fast tokenizer

* move ci to 85

* support flash attention 2

* add all-in-one data preparation script. Fix colossal-llama2-chat chat template

* add hardware requirements

* move ci test data

* fix save_model, add unwrap

* fix missing bos

* fix missing bos; support grad accumulation with gemini

* fix ci

* fix ci

* fix ci

* fix llama2 chat template config

* debug sft

* debug sft

* fix colossalai version requirement

* fix ci

* add sanity check to prevent NaN loss

* fix requirements

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* add dummy data generation script

* update readme

* update readme

* update readme and ignore

* fix logger bug

* support parallel_output

* modify data preparation logic

* fix tokenization

* update lr

* fix inference

* run pre-commit

---------

Co-authored-by: Tong Li <tong.li352711588@gmail.com>

* [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogeneous shard policy for llama (#5508)

* feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig`

* feat: apply `GradientCheckpointConfig` to policy and llama_forward

* feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager

* fix: add optional args for `distribute_layer` and `get_stage_index`

* fix: fix changed API calls

* test: update llama tests

* style: polish `GradientCheckpointConfig`

* fix: fix pipeline utils tests

* fix incorrect sharding without zero (#5545)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [shardformer] Sequence Parallelism Optimization (#5533)

* sequence parallel optimization

* validate sequence parallel in llama (code to be polished)

* shardformer api writing

* integrate sequence parallel in ShardFormer

* fix pp bugs and sp bugs for LlaMa model

* integrating ring-based sequence parallelism into ShardFormer

* [sequence parallelism]: Add fused megatron function

* integrating ring-based sequence parallelism into ShardFormer

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* fix bugs when using sp and flash attention together

* fix operation function name

* support flash attention for ulysses-style sp

* clarify sp process group

* fix compatibility bugs in moe plugin

* fix fused linear bugs

* fix linear layer test

* support gpt model all-to-all sp

* modify shard data dimension (meant to be dim=-1)

* support megatron-style sp and distributed attn for llama model

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatibility

* finish sp mode 3 support for gpt

* using all_to_all_single when batch size is 1

* support mode 2 sp in gpt2 (#5)

* [shardformer] add megatron sp to llama

* support llama7B 128k with distributed attention

* [shardformer] robustness enhancement

* add block attn

* sp mode 1: keep input as a complete sequence

* fix sp compatibility

* refactor ring implementation

* support mode 2 sp in gpt2

* polish code

* enable distributed attn mask when using sp mode 2 and 3 in llama

* automatically enable flash attn when using sp mode 2 and 3 in llama

* inplace attn mask

* add zero2 support for sequence parallel

* polish code

* fix bugs

* fix gemini checkpoint io

* loose tensor checking atol and rtol

* add comment

* fix llama layernorm grad

* fix zero grad

* fix zero grad

* fix conflict

* update split and gather auto grad func

* sequence parallel: inside text split (#6)

* polish code (part 1)

* polish code (part 2)

* polish code (part 2.5)

* polish code (part 3)

* sequence parallel: inside text split

* miscellaneous minor fixes

* polish code

* fix ulysses style ZeRO

* sequence parallel: inside text split

* miscellaneous minor fixes

* disaggregate sp group and dp group for sp

* fix llama and gpt sp

* polish code

* move ulysses grad sync to ddp (#9)

* remove zero_stage and unbind the grad sync for alltoall sp

* add 2d group creation test

* move ulysses grad sync to ddp

* add 2d group creation test

* remove useless code

* change shard config so that enable_all_optimizations does not enable sp

* add sp warnings for several models

* remove useless code

---------

Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>

* [hotfix] quick fixes to make legacy tutorials runnable (#5559)

Co-authored-by: Edenzzzz <wtan45@wisc.edu>

* [fix] fix typo s/muiti-node /multi-node etc. (#5448)

* [hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548)

* [devops] remove post commit ci (#5566)

* [devops] remove post commit ci

* [misc] run pre-commit on all files

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

---------

Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [shardformer]enable padding vocabulary size. (#5489)

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

* fix

fix

fix

* fix gather output

* fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* revert

* padding vocab

* padding vocab

* fix

* fix

* fix

* test ci

* fix

fix

fix

fix

* fix

fix

* fix

* fix

* Update hybrid_parallel_plugin.py

fix

fix

fix

* fix

fix

* fix

fix

* fix

* resolve super init

resolve super init

resolve super init

resolve super init

* resolve comments

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* vocab checkpointio

* padding vocab_size when using pipeline parallelism

padding vocab_size when using pipeline parallelism

fix

fix

* fix

fix

fix

* fix

* fix

fix resize embedding

fix resize embedding

* fix resize embedding

fix

* revert

* revert

* padding vocab

* fix

* fix

fix

* fix

fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix ci

* fix

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

* cherry-pick

* revert moe modify

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix

fix

fix

fix

fix

fix

fix

fix

* resolve comments

resolve comments

resolve comments

resolve comments

resolve comments

* ptensor

ptensor

resolve comments

fix

fix

fix

fix

fix

resolve comments

resolve comments

resolve comments

resolve comments

resolve comments

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix rebase

* fix rebase

---------

Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: digger yu <digger-yu@outlook.com>
Co-authored-by: binmakeswell <binmakeswell@gmail.com>
Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com>
Co-authored-by: Wenhao Chen <cwher@outlook.com>
Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: github-actions <github-actions@github.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com>
Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
flybird11111 2024-04-18 16:10:18 +08:00 committed by GitHub
parent 3788fefc7a
commit a0ad587c24
36 changed files with 1352 additions and 289 deletions

View File

@ -44,10 +44,10 @@ ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
def get_param_info(optim: Optimizer):
# Get a backup of necessary information of parameters for future use, which includes:
# 1. A mapping from integer param_id to param32 shape.
if optim is None:
return {}
param_info = {"id2shape": {}}
start_index = 0
for group in optim.param_groups:
for param_id, param in enumerate(group["params"], start_index):
@ -527,7 +527,7 @@ class GeminiPlugin(DPPluginBase):
dataloader: Optional[DataLoader] = None,
lr_scheduler: Optional[LRScheduler] = None,
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
optimizer_params_info = get_param_info(optimizer)
params_info = get_param_info(optimizer)
if not isinstance(model, ModelWrapper):
# convert model to sync bn
# FIXME(ver217): gemini does not support sync bn
@ -558,7 +558,7 @@ class GeminiPlugin(DPPluginBase):
**self.zero_optim_config,
**self.optim_kwargs,
tp_group=self.tp_group,
optimizer_params_info=optimizer_params_info,
params_info=params_info,
verbose=self.verbose,
)

View File

@ -213,12 +213,7 @@ def get_param_info(optim: Optimizer):
if optim is None:
return {}
param_info = {
"param_groups": [],
"param2id": {},
"id2param": {},
"param2shape": {},
}
param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
start_index = 0
for group in optim.param_groups:
packed_group = {k: v for k, v in group.items() if k != "params"}
@ -947,6 +942,8 @@ class HybridParallelPlugin(PipelinePluginBase):
num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
make_vocab_size_divisible_by (int, optional): used when padding the vocabulary size so that a faster kernel can be chosen. Defaults to 64.
"""
def __init__(
@ -989,6 +986,7 @@ class HybridParallelPlugin(PipelinePluginBase):
num_model_chunks: int = 1,
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
enable_metadata_cache: bool = True,
make_vocab_size_divisible_by: int = 64,
) -> None:
super().__init__()
assert (
@ -1095,6 +1093,7 @@ class HybridParallelPlugin(PipelinePluginBase):
sequence_parallelism_mode=sequence_parallelism_mode,
enable_sequence_overlap=enable_sequence_overlap,
parallel_output=parallel_output,
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
gradient_checkpoint_config=gradient_checkpoint_config,
)
self.amp_config = dict(
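
For reference, a minimal Python sketch (not part of the diff; the vocabulary size is hypothetical) of the padding rule the new make_vocab_size_divisible_by argument drives: the effective multiple is make_vocab_size_divisible_by times the tensor-parallel size, matching the VocabParallelEmbedding1D change further below.

def padded_vocab_size(vocab_size: int, make_vocab_size_divisible_by: int, tp_size: int) -> int:
    # Round the vocabulary up to a multiple of (divisor * tp_size) so each TP rank gets an equal slice.
    multiple = make_vocab_size_divisible_by * tp_size
    if vocab_size % multiple != 0:
        vocab_size = vocab_size + multiple - (vocab_size % multiple)
    return vocab_size

padded_vocab_size(50257, 64, 2)  # -> 50304, i.e. 25152 rows per rank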

View File

@ -14,6 +14,12 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.tensor.padded_tensor import (
init_as_padded_tensor,
is_padded_tensor,
to_padded_tensor,
to_unpadded_tensor,
)
from colossalai.utils import get_current_device
from .general_checkpoint_io import GeneralCheckpointIO
@ -32,6 +38,7 @@ from .utils import (
save_param_groups,
save_state_dict,
save_state_dict_shards,
search_padding_dim,
search_tp_partition_dim,
sharded_optimizer_loading_epilogue,
)
@ -89,6 +96,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if param is None:
continue
# Gather tensor pieces when using tensor parallel.
if is_padded_tensor(param):
param = to_unpadded_tensor(param)
param_ = gather_distributed_param(param, keep_vars=False)
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
if block is not None:
@ -231,7 +240,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
# When pipeline is used, each stage produces its own shard files and index files.
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
final_index_file_path = copy.deepcopy(save_index_file)
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
@ -251,6 +259,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
use_safetensors=use_safetensors,
use_pp_format=True,
)
if control_saving:
assert (
self.dp_rank == 0 and self.tp_rank == 0
@ -867,6 +876,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
dist.all_gather(gather_tensor, v, group=tp_group)
v = torch.cat(gather_tensor, dim=partition_dim)
padding_dim = search_padding_dim(v.shape, original_shape)
if padding_dim is not None:
v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
v = to_unpadded_tensor(v)
state_[k] = v.detach().clone().to(device)
return state_
@ -899,6 +913,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
if isinstance(v, torch.Tensor) and k != "step":
# Shard state along tensor parallel group.
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
global_shape = current_shape
if partition_dim is not None:
# pad embedding params
global_shape = (
*current_shape[:partition_dim],
current_shape[partition_dim] * self.tp_size,
*current_shape[partition_dim + 1 :],
)
padding_dim = search_padding_dim(global_shape, original_shape)
if padding_dim is not None:
v = to_padded_tensor(v, global_shape[padding_dim], padding_dim)
if partition_dim is not None:
slice_size = current_shape[partition_dim]
v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
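
For orientation, a plain-PyTorch sketch (assuming the padding sits on dim 0, as it does for vocabulary tensors; sizes are hypothetical) of the pad-then-shard step applied to optimizer state above:

import torch

def pad_and_shard(v: torch.Tensor, padded_len: int, tp_size: int, tp_rank: int) -> torch.Tensor:
    pad = padded_len - v.shape[0]
    if pad > 0:
        # zero-pad the vocabulary dimension up to the padded global length
        v = torch.cat([v, v.new_zeros(pad, *v.shape[1:])], dim=0)
    slice_size = padded_len // tp_size
    # keep only this rank's slice, mirroring v.split(slice_size, dim=partition_dim)[tp_rank]
    return v.split(slice_size, dim=0)[tp_rank]

pad_and_shard(torch.randn(50257, 768), 50304, 2, 1).shape  # torch.Size([25152, 768])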

View File

@ -120,6 +120,15 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
return partition_dim
def search_padding_dim(global_shape: torch.Size, original_shape: torch.Size) -> Optional[int]:
padding_dim = None
for dim, length in enumerate(global_shape):
if length > original_shape[dim]:
padding_dim = dim
break
return padding_dim
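
As a quick illustration (hypothetical shapes, not part of the diff), the helper returns the first dimension where the gathered global shape exceeds the original shape, i.e. the padded vocabulary dimension, or None when nothing was padded:

import torch  # torch.Size inputs; plain tuples behave the same way

search_padding_dim(torch.Size([50304, 768]), torch.Size([50257, 768]))  # -> 0 (vocabulary dim was padded)
search_padding_dim(torch.Size([768, 3072]), torch.Size([768, 3072]))    # -> None (nothing padded)
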
# ======================================
# Helper classes and functions for saving shard file
# ======================================

View File

@ -1,8 +1,8 @@
from ._operation import all_to_all_comm
from .attn import AttnMaskType, ColoAttention
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
from .embedding import Embedding1D, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
from .loss import cross_entropy_1d
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
from .parallel_module import ParallelModule
@ -25,6 +25,9 @@ __all__ = [
"FusedRMSNorm",
"FusedLinear1D_Col",
"ParallelModule",
"PaddingEmbedding",
"PaddingLMHead",
"VocabParallelLMHead1D",
"AttnMaskType",
"ColoAttention",
"all_to_all_comm",

View File

@ -21,10 +21,10 @@ from colossalai.tensor.d_tensor.api import (
)
from ._operation import gather_forward_split_backward, reduce_forward
from .parallel_module import ParallelModule
from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset
__all__ = ["Embedding1D", "VocabParallelEmbedding1D"]
__all__ = ["Embedding1D", "VocabParallelEmbedding1D", "PaddingEmbedding"]
class Embedding1D(ParallelModule):
@ -161,7 +161,80 @@ class Embedding1D(ParallelModule):
return output_parallel
class VocabParallelEmbedding1D(ParallelModule):
class PaddingEmbedding(PaddingParallelModule):
def __init__(
self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
device: torch.device = None,
weight: Optional[nn.Parameter] = None,
make_vocab_size_divisible_by: int = 64,
*args,
**kwargs,
):
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.embed_args = args
self.embed_kwargs = kwargs
self.padding_idx = padding_idx
if num_embeddings % make_vocab_size_divisible_by != 0:
self.num_embeddings = (
num_embeddings + make_vocab_size_divisible_by - (num_embeddings % make_vocab_size_divisible_by)
)
# create weight and bias
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
super().__init__(self.num_embeddings, num_embeddings, weight)
if weight is None:
self.reset_parameters()
def reset_parameters(self) -> None:
init.normal_(self.weight)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input: Tensor) -> Tensor:
return F.embedding(input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
@staticmethod
def from_native_module(
module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> PaddingParallelModule:
r"""
Convert a native pytorch embedding module to a parallel module.
"""
LazyInitContext.materialize(module)
# get the origin attributes
num_embeddings = module.num_embeddings
embedding_dim = module.embedding_dim
padding_idx = module.padding_idx
device = module.weight.device
# create the parallel module
padding_embedding = PaddingEmbedding(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
padding_idx=padding_idx,
device=device,
weight=module.weight,
*args,
**kwargs,
)
return padding_embedding
class VocabParallelEmbedding1D(PaddingParallelModule):
r"""Embedding parallelized in the vocabulary dimension.
Args:
@ -201,10 +274,10 @@ class VocabParallelEmbedding1D(ParallelModule):
process_group: ProcessGroup = None,
weight: Optional[nn.Parameter] = None,
weight_initializer: Callable = init.normal_(),
make_vocab_size_divisible_by: int = 64,
*args,
**kwargs,
):
super().__init__()
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.embed_args = args
@ -214,8 +287,23 @@ class VocabParallelEmbedding1D(ParallelModule):
tensor_parallel_size = dist.get_world_size(group=process_group)
tensor_parallel_rank = dist.get_rank(group=process_group)
self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
self.num_embeddings = self.num_embeddings_per_partition
# generate weight and bias
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
# calculate new padding size
multiple = make_vocab_size_divisible_by * tensor_parallel_size
if num_embeddings % multiple != 0:
self.num_embeddings = num_embeddings + multiple - (num_embeddings % multiple)
# resize vocabulary size
super().__init__(self.num_embeddings, num_embeddings, weight)
# deal with tensor parallelism
self.num_embeddings_per_partition = divide(self.num_embeddings, tensor_parallel_size)
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
@ -226,13 +314,6 @@ class VocabParallelEmbedding1D(ParallelModule):
seed = torch.random.initial_seed()
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
# parameter
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if not is_distributed_tensor(self.weight):
sharded_weight = shard_rowwise(self.weight.data, process_group)
sharded_tensor_to_existing_param(sharded_weight, self.weight)
@ -243,7 +324,7 @@ class VocabParallelEmbedding1D(ParallelModule):
@staticmethod
def from_native_module(
module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
) -> ParallelModule:
) -> PaddingParallelModule:
r"""
Convert a native pytorch embedding module to a parallel module.
"""
@ -303,11 +384,9 @@ class VocabParallelEmbedding1D(ParallelModule):
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
output_parallel = F.embedding(
masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs
)
# Mask the output embedding.
embedding_output = output_parallel.clone()
embedding_output[input_mask, :] = 0.0
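
To make the masking above concrete, a plain-PyTorch sketch (hypothetical per-rank vocabulary slice; the final cross-rank reduction is omitted) of how out-of-range token ids are handled on one TP rank:

import torch

vocab_start, vocab_end = 128, 256                            # this rank owns rows [128, 256) of the padded vocabulary
input_ids = torch.tensor([[5, 130, 300]])
mask = (input_ids < vocab_start) | (input_ids >= vocab_end)
local_ids = (input_ids - vocab_start).masked_fill(mask, 0)   # safe local indices; masked positions point at row 0
# After F.embedding, the rows at masked positions are zeroed, so the
# subsequent reduce across TP ranks keeps only the owning rank's embedding.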

View File

@ -32,7 +32,7 @@ from ._operation import (
reducescatter_forward_gather_backward,
split_forward_gather_backward,
)
from .parallel_module import ParallelModule
from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset
__all__ = ["Linear1D_Col", "Linear1D_Row"]
@ -84,8 +84,9 @@ class Linear1D_Col(ParallelModule):
bias_: Optional[Parameter] = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
**kwargs,
):
super().__init__()
super().__init__(weight=weight, bias_=bias_, **kwargs)
# Keep input parameters
self.in_features = in_features
@ -118,6 +119,7 @@ class Linear1D_Col(ParallelModule):
else:
weight.data = weight.data.to(device=device, dtype=dtype)
self.weight = weight
if not is_distributed_tensor(self.weight):
sharded_weight = shard_rowwise(self.weight.data, self.process_group)
sharded_tensor_to_existing_param(sharded_weight, self.weight)
@ -140,7 +142,7 @@ class Linear1D_Col(ParallelModule):
@staticmethod
def from_native_module(
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
) -> ParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
@ -173,7 +175,6 @@ class Linear1D_Col(ParallelModule):
process_group=process_group,
weight=module.weight,
bias_=module.bias,
*args,
**kwargs,
)
@ -322,7 +323,7 @@ class Linear1D_Row(ParallelModule):
@staticmethod
def from_native_module(
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
) -> ParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
@ -356,7 +357,6 @@ class Linear1D_Row(ParallelModule):
process_group=process_group,
weight=module.weight,
bias_=module.bias,
*args,
**kwargs,
)
@ -439,3 +439,211 @@ class Linear1D_Row(ParallelModule):
return output
else:
return output, self.bias
class PaddingLMHead(PaddingParallelModule):
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
make_vocab_size_divisible_by: int = 64,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
):
# Keep input parameters
self.in_features = in_features
self.out_features = out_features
if out_features % make_vocab_size_divisible_by != 0:
self.out_features = (
out_features + make_vocab_size_divisible_by - (out_features % make_vocab_size_divisible_by)
)
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
weight = Parameter(torch.empty(out_features, self.in_features, **factory_kwargs))
else:
weight.data = weight.data.to(device=device, dtype=dtype)
if bias:
if bias_ is None:
self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
else:
bias_.data = bias_.data.to(device=device, dtype=dtype)
else:
bias_ = None
# resize embeddings
super().__init__(self.out_features, out_features, weight, bias_)
if weight is None:
self.reset_parameters(weight_initializer, bias_initializer)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.out_features
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
@staticmethod
def from_native_module(
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
) -> PaddingParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
# ensure only one process group is passed
lm_head_linear = PaddingLMHead(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
weight=module.weight,
bias_=module.bias,
**kwargs,
)
return lm_head_linear
def forward(self, input: Tensor) -> Tensor:
output = F.linear(input, self.weight, self.bias)
output = output[..., : self.old_num_embeddings]
return output
class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
r"""Linear layer with column parallelism.
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
its second dimension as :math:`A = [A_1, ..., A_p]`.
Args:
in_features (int): size of each input sample.
out_features (int): size of each output sample.
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
device (`torch.device`): The device of parameters, defaults to None.
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
gather_output (bool, optional): If true, call all-gather on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
which is preserved for kernel fusion, defaults to False
weight_initializer (`typing.Callable`):
The initializer of weight, defaults to kaiming uniform initializer.
bias_initializer (`typing.Callable`):
The initializer of bias, defaults to xavier uniform initializer.
For more details about ``initializer``, please refer to
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
"""
def __init__(
self,
in_features: int,
out_features: int,
bias: bool = True,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
weight: Optional[Parameter] = None,
bias_: Optional[Parameter] = None,
make_vocab_size_divisible_by: int = 64,
**kwargs,
):
# create weight and bias
if weight is None:
factory_kwargs = {"device": device, "dtype": dtype}
weight = Parameter(torch.empty(out_features, self.in_features, **factory_kwargs))
if bias:
if bias_ is None:
bias_ = Parameter(torch.empty(out_features, **factory_kwargs))
else:
bias_ = None
# calculate new vocab size
self.tensor_parallel_size = dist.get_world_size(group=process_group)
new_out_features = out_features
multiple = make_vocab_size_divisible_by * self.tensor_parallel_size
if out_features % multiple != 0:
new_out_features = out_features + multiple - (out_features % multiple)
super().__init__(
in_features=in_features,
out_features=new_out_features,
bias=bias,
device=device,
process_group=process_group,
weight=weight,
bias_=bias_,
**kwargs,
new_num_embeddings=new_out_features,
old_num_embeddings=out_features,
)
# get the length of valid embeddings
tp_rank = dist.get_rank(process_group)
partition_size = self.new_num_embeddings // dist.get_world_size(process_group)
if self.old_num_embeddings >= (tp_rank + 1) * partition_size:
self.num_valid_embeddings_local = partition_size
elif self.old_num_embeddings >= tp_rank * partition_size:
self.num_valid_embeddings_local = self.old_num_embeddings - tp_rank * partition_size
else:
self.num_valid_embeddings_local = 0
@staticmethod
def from_native_module(
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
) -> PaddingParallelModule:
r"""
Convert a native PyTorch linear layer to a parallelized linear layer.
"""
LazyInitContext.materialize(module)
# get the attributes
in_features = module.in_features
out_features = module.out_features
bias = module.bias is not None
device = module.weight.device
lm_head_linear = VocabParallelLMHead1D(
in_features=in_features,
out_features=out_features,
bias=bias,
device=device,
process_group=process_group,
weight=module.weight,
bias_=module.bias,
**kwargs,
)
return lm_head_linear
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
# get forward output
if self.skip_bias_add:
output, bias = super().forward(input_)
else:
output = super().forward(input_)
# delete the padding of output
if self.gather_output:
output = output[..., : self.old_num_embeddings]
else:
output = output[..., : self.num_valid_embeddings_local]
# return
if self.skip_bias_add:
return output, bias
return output
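
A worked example (hypothetical sizes, not part of the diff) of the per-rank valid-embedding count computed in the constructor above, which the forward pass uses to strip padded logits when gather_output is False:

old_num, new_num, tp_size = 50257, 50304, 2
partition_size = new_num // tp_size                          # 25152 rows per rank
for tp_rank in range(tp_size):
    if old_num >= (tp_rank + 1) * partition_size:
        num_valid = partition_size                           # rank 0: all 25152 rows are real vocabulary
    elif old_num >= tp_rank * partition_size:
        num_valid = old_num - tp_rank * partition_size       # rank 1: 25105 real rows, the remaining 47 are padding
    else:
        num_valid = 0
    print(tp_rank, num_valid)                                # 0 25152 / 1 25105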

View File

@ -15,7 +15,14 @@ class DistCrossEntropy(Function):
"""
@staticmethod
def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int, process_group: ProcessGroup):
def forward(
ctx,
vocab_logits: torch.Tensor,
target: torch.Tensor,
ignore_index: int,
process_group: ProcessGroup,
vocab_size: int,
):
r"""
Calculate the cross entropy loss before gathering; the original loss function is:
loss = -log(exp(x[class]) / sum(exp(x[i])))
@ -41,15 +48,21 @@ class DistCrossEntropy(Function):
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
# mask the target in the local device
partition_vocab_size = vocab_logits.size()[-1]
rank = dist.get_rank(group=process_group)
world_size = dist.get_world_size(group=process_group)
global_vocab_size = partition_vocab_size * world_size
if vocab_size == None:
partition_vocab_size = vocab_logits.size()[-1]
global_vocab_size = partition_vocab_size * world_size
else:
global_vocab_size = vocab_size
partition_vocab_size = global_vocab_size // world_size
# [down, up) => false, other device and -100 => true
delta = (global_vocab_size + world_size - 1) // world_size
down_threshold = rank * delta
up_threshold = down_threshold + delta
if up_threshold > global_vocab_size:
up_threshold = global_vocab_size
mask = (target < down_threshold) | (target >= up_threshold)
masked_target = target.clone() - down_threshold
masked_target[mask] = 0
@ -57,7 +70,8 @@ class DistCrossEntropy(Function):
# reshape the logits and target
# reshape the vocab_logits to [batch_size * seq_len, vocab_size]
# reshape the labels to [batch_size * seq_len]
logits_2d = vocab_logits.view(-1, partition_vocab_size)
self_vocab_size = vocab_logits.size()[-1]
logits_2d = vocab_logits.view(-1, self_vocab_size)
masked_target_1d = masked_target.view(-1)
# extract the x[class] and set the x[other device] to zero
@ -104,10 +118,14 @@ class DistCrossEntropy(Function):
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
grad_logits.mul_(grad_output.unsqueeze(dim=-1))
return grad_logits, None, None, None
return grad_logits, None, None, None, None
def cross_entropy_1d(
vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None
vocab_logits: torch.Tensor,
labels: torch.Tensor,
ignore_index: int = -100,
process_group: ProcessGroup = None,
vocab_size: int = None,
) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group)
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size)
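
For intuition, a worked example (hypothetical numbers) of the per-rank label window derived in the forward above; labels outside the window are masked out locally and recovered by the reduction across ranks:

global_vocab_size, world_size, rank = 50304, 4, 1
delta = (global_vocab_size + world_size - 1) // world_size     # 12576 labels per rank
down_threshold = rank * delta                                  # 12576
up_threshold = min(down_threshold + delta, global_vocab_size)  # 25152
# On rank 1, only targets in [12576, 25152) contribute to the local loss term.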

View File

@ -3,7 +3,7 @@
import itertools
from abc import ABC, abstractmethod
from typing import List, Union
from typing import List, Optional, Union
import torch
import torch.nn as nn
@ -20,11 +20,15 @@ from colossalai.tensor.d_tensor import (
is_distributed_tensor,
sharded_tensor_to_param,
)
from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
__all__ = ["ParallelModule"]
class ParallelModule(nn.Module, ABC):
def __init__(self, **kwargs):
super().__init__()
@abstractmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None
@ -54,7 +58,7 @@ class ParallelModule(nn.Module, ABC):
"""
for name, param in self._parameters.items():
if param is not None:
destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars)
destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars).data
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
@ -171,3 +175,187 @@ class ParallelModule(nn.Module, ABC):
input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child
if input_name not in self._modules and input_name not in local_state:
unexpected_keys.append(key)
class PaddingParallelModule(ParallelModule):
def __init__(
self,
new_num_embeddings: int,
old_num_embeddings: int,
weight: Optional[nn.Parameter],
bias_: Optional[nn.Parameter] = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
self.new_num_embeddings = new_num_embeddings
self.old_num_embeddings = old_num_embeddings
self.weight = weight
self.bias = bias_
if not (is_distributed_tensor(self.weight) or self.weight.shape[0] == self.new_num_embeddings):
self.resize_embedding_weight()
if self.bias is not None and not (
is_distributed_tensor(self.bias) or self.bias.shape[0] == self.new_num_embeddings
):
self.resize_embedding_bias()
@abstractmethod
def from_native_module(
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None
) -> "PaddingParallelModule":
"""
Convert a native PyTorch module to a parallelized module.
Args:
module (nn.Module): the module to be converted.
process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication.
If this is a list, the process group at the ith index of the list will correspond to the process group
in the ith axis of the device mesh. Defaults to None, which means the global process group.
"""
raise NotImplementedError
def _save_to_state_dict(self, destination, prefix, keep_vars):
r"""Saves module state to `destination` dictionary, containing a state
of the module, but not its descendants. This is called on every
submodule in :meth:`~torch.nn.Module.state_dict`.
In rare cases, subclasses can achieve class-specific behavior by
overriding this method with custom logic.
Args:
destination (dict): a dict where state will be stored
prefix (str): the prefix for parameters and buffers used in this
module
"""
for name, param in self._parameters.items():
if param is not None:
param = gather_distributed_param(param, keep_vars=keep_vars)
if is_padded_tensor(param):
param = to_unpadded_tensor(param)
destination[prefix + name] = param.data
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
destination[prefix + name] = buf if keep_vars else buf.detach()
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
destination[extra_state_key] = self.get_extra_state()
def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
):
r"""Copies parameters and buffers from :attr:`state_dict` into only
this module, but not its descendants. This is called on every submodule
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
For state dicts without metadata, :attr:`local_metadata` is empty.
Subclasses can achieve class-specific backward compatible loading using
the version number at `local_metadata.get("version", None)`.
.. note::
:attr:`state_dict` is not the same object as the input
:attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
it can be modified.
Args:
state_dict (dict): a dict containing parameters and
persistent buffers.
prefix (str): the prefix for parameters and buffers used in this
module
local_metadata (dict): a dict containing the metadata for this module.
See
strict (bool): whether to strictly enforce that the keys in
:attr:`state_dict` with :attr:`prefix` match the names of
parameters and buffers in this module
missing_keys (list of str): if ``strict=True``, add missing keys to
this list
unexpected_keys (list of str): if ``strict=True``, add unexpected
keys to this list
error_msgs (list of str): error messages should be added to this
list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict`
"""
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
for name, param in local_state.items():
key = prefix + name
if key in state_dict:
input_param = state_dict[key]
if not torch.overrides.is_tensor_like(input_param):
error_msgs.append(
'While copying the parameter named "{}", '
"expected torch.Tensor or Tensor-like object from checkpoint but "
"received {}".format(key, type(input_param))
)
continue
if is_padded_tensor(param):
input_param = to_padded_tensor(input_param, param._current_length, param._padding_dim)
if is_distributed_tensor(param):
# shard the input param
device_mesh = get_device_mesh(param)
sharding_spec = get_sharding_spec(param)
sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec)
input_param = sharded_tensor_to_param(sharded_tensor)
elif is_customized_distributed_tensor(param):
input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn)
# This is used to avoid copying uninitialized parameters into
# non-lazy modules, since they don't have the hook to do the checks
# in such case, it will error when accessing the .shape attribute.
is_param_lazy = torch.nn.parameter.is_lazy(param)
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
input_param = input_param[0]
if not is_param_lazy and input_param.shape != param.shape:
# local shape should match the one in checkpoint
error_msgs.append(
"size mismatch for {}: copying a param with shape {} from checkpoint, "
"the shape in current model is {}.".format(key, input_param.shape, param.shape)
)
continue
try:
with torch.no_grad():
param.copy_(input_param)
except Exception as ex:
error_msgs.append(
'While copying the parameter named "{}", '
"whose dimensions in the model are {} and "
"whose dimensions in the checkpoint are {}, "
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
)
elif strict:
missing_keys.append(key)
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
if extra_state_key in state_dict:
self.set_extra_state(state_dict[extra_state_key])
elif strict:
missing_keys.append(extra_state_key)
elif strict and (extra_state_key in state_dict):
unexpected_keys.append(extra_state_key)
if strict:
for key in state_dict.keys():
if key.startswith(prefix) and key != extra_state_key:
input_name = key[len(prefix) :]
input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child
if input_name not in self._modules and input_name not in local_state:
unexpected_keys.append(key)
def resize_embedding_weight(self):
self.weight = to_padded_tensor(self.weight, self.new_num_embeddings, 0)
def resize_embedding_bias(self):
self.bias = to_padded_tensor(self.bias, self.new_num_embeddings, 0)

View File

@ -26,7 +26,6 @@ from colossalai.shardformer.layer._operation import gather_forward_split_backwar
from colossalai.shardformer.shard import ShardConfig
from ..layer import cross_entropy_1d
from ..layer._operation import gather_forward_split_backward
logger = logging.get_logger(__name__)
@ -397,13 +396,11 @@ class GPT2PipelineForwards:
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
)
else:
loss = loss_fct(shift_logits, shift_labels)
if not shard_config.parallel_output:
lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
if not return_dict:
output = (lm_logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
@ -1301,12 +1298,12 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
shift_labels = shift_labels.view(-1)
loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
)
if not shard_config.parallel_output:
lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output

View File

@ -316,7 +316,10 @@ class LlamaPipelineForwards:
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)
@ -735,11 +738,13 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
new_vocab_size = logits.shape[-1]
shift_logits = shift_logits.view(-1, new_vocab_size)
loss = cross_entropy_1d(
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
)
if not return_dict:

View File

@ -195,3 +195,12 @@ class Policy(ABC):
List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
"""
return []
def tie_weight_check(self):
input_embedding = self.model.get_input_embeddings()
output_embedding = self.model.get_output_embeddings()
return (
input_embedding is not None
and output_embedding is not None
and id(input_embedding.weight) == id(output_embedding.weight)
)
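
A minimal torch-only sketch (assumed setup, not ColossalAI code) of the condition tie_weight_check detects, i.e. the input and output embeddings sharing one Parameter:

import torch.nn as nn

emb = nn.Embedding(100, 16)
head = nn.Linear(16, 100, bias=False)
head.weight = emb.weight            # tie the two layers, as many HF checkpoints do
id(emb.weight) == id(head.weight)   # True: weights are tied, so the policies below pick the padding-aware embedding class when TP is off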

View File

@ -37,17 +37,7 @@ class BertPolicy(Policy):
pass
def preprocess(self):
# reshape the embedding layer
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
# TODO:
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
self.tie_weight = self.tie_weight_check()
return self.model
def module_policy(self):
@ -62,6 +52,13 @@ class BertPolicy(Policy):
policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = col_nn.VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = col_nn.PaddingEmbedding
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
@ -150,10 +147,6 @@ class BertPolicy(Policy):
policy[BertEmbeddings] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForReplicatedInput,
@ -168,6 +161,18 @@ class BertPolicy(Policy):
target_key=BertModel,
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=embedding_cls,
)
],
policy=policy,
target_key=BertEmbeddings,
)
# optimization configuration
# Handle bert layer
self.append_or_create_submodule_replacement(
@ -237,8 +242,21 @@ class BertPolicy(Policy):
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="decoder",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True},
target_module=col_nn.VocabParallelLMHead1D,
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
},
),
policy=base_policy,
target_key=BertLMPredictionHead,
)
else:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="decoder",
target_module=col_nn.PaddingLMHead,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=base_policy,
target_key=BertLMPredictionHead,

View File

@ -17,16 +17,7 @@ class BlipPolicy(Policy):
pass
def preprocess(self):
# reshape the embedding layer
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
# TODO:
vocab_size = self.model.config.qformer_config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
self.tie_weight = self.tie_weight_check()
return self.model
def module_policy(self):
@ -43,6 +34,13 @@ class BlipPolicy(Policy):
policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = col_nn.VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = col_nn.PaddingEmbedding
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
@ -202,22 +200,48 @@ class BlipPolicy(Policy):
],
)
policy[OPTForCausalLM] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="model.decoder.embed_tokens",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True},
),
]
)
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="model.decoder.embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
],
policy=policy,
target_key=OPTForCausalLM,
)
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=col_nn.VocabParallelLMHead1D,
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
},
),
],
policy=policy,
target_key=OPTForCausalLM,
)
else:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=col_nn.PaddingLMHead,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
],
policy=policy,
target_key=OPTForCausalLM,
)
# optimization configuration
# Handle Blip2EncoderLayer layer
self.append_or_create_submodule_replacement(

View File

@ -35,16 +35,7 @@ class BloomPolicy(Policy):
pass
def preprocess(self):
# reshape the embedding layer
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
self.tie_weight = self.tie_weight_check()
return self.model
def module_policy(self):
@ -52,6 +43,13 @@ class BloomPolicy(Policy):
policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = col_nn.VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = col_nn.PaddingEmbedding
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
@ -112,12 +110,19 @@ class BloomPolicy(Policy):
method_replacement={
"build_alibi_tensor": build_bloom_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group)
},
sub_module_replacement=[
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
)
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
],
policy=policy,
target_key=BloomModel,
)
# optimization configuration
@ -282,7 +287,21 @@ class BloomForCausalLMPolicy(BloomPolicy):
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
suffix="lm_head",
target_module=col_nn.VocabParallelLMHead1D,
kwargs=dict(
gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
),
),
policy=policy,
target_key=BloomForCausalLM,
)
else:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="lm_head",
target_module=col_nn.PaddingLMHead,
kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
),
policy=policy,
target_key=BloomForCausalLM,

View File

@ -25,20 +25,12 @@ class ChatGLMPolicy(Policy):
pass
def preprocess(self):
# Resize embedding
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.padded_vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
if self.pipeline_stage_manager is not None:
# the batch_size_dim is bounded to Model
bsz_dim = 1
setattr(self.model, "batch_size_dim", bsz_dim)
self.tie_weight = self.tie_weight_check()
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
@ -46,6 +38,13 @@ class ChatGLMPolicy(Policy):
policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = col_nn.VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = col_nn.PaddingEmbedding
if self.shard_config.enable_fused_normalization:
if self.model.config.rmsnorm:
norm_cls = col_nn.FusedRMSNorm
@ -68,16 +67,6 @@ class ChatGLMPolicy(Policy):
sp_partial_derived = sp_mode == "split_gather"
if self.shard_config.enable_tensor_parallelism:
policy[ChatGLMModel] = ModulePolicyDescription(
attribute_replacement={},
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="embedding.word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
)
],
)
policy[GLMBlock] = ModulePolicyDescription(
attribute_replacement={
"self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
@ -114,6 +103,19 @@ class ChatGLMPolicy(Policy):
),
],
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="embedding.word_embeddings",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
],
policy=policy,
target_key=ChatGLMModel,
)
# optimization configuration
self.append_or_create_submodule_replacement(
description=[

View File

@ -32,16 +32,7 @@ class FalconPolicy(Policy):
pass
def preprocess(self):
# reshape the embedding layer
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
self.tie_weight = self.tie_weight_check()
return self.model
def module_policy(self):
@ -58,6 +49,14 @@ class FalconPolicy(Policy):
warnings.warn("Falcon doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = col_nn.VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = col_nn.PaddingEmbedding
if self.shard_config.enable_tensor_parallelism:
attn_attribute_replacement = {
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
@ -98,12 +97,19 @@ class FalconPolicy(Policy):
method_replacement={
"build_alibi_tensor": build_falcon_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group)
},
sub_module_replacement=[
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="word_embeddings",
target_module=col_nn.VocabParallelEmbedding1D,
)
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
],
policy=policy,
target_key=FalconModel,
)
# optimization configuration
@ -232,11 +238,26 @@ class FalconForCausalLMPolicy(FalconPolicy):
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
suffix="lm_head",
target_module=col_nn.VocabParallelLMHead1D,
kwargs=dict(
gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
),
),
policy=policy,
target_key=FalconForCausalLM,
)
else:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="lm_head",
target_module=col_nn.PaddingLMHead,
kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
),
policy=policy,
target_key=FalconForCausalLM,
)
if self.pipeline_stage_manager:
self.set_pipeline_forward(
model_cls=FalconForCausalLM,

View File

@ -34,12 +34,7 @@ class GPT2Policy(Policy):
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
self.tie_weight = self.tie_weight_check()
return self.model
def module_policy(self):
@ -47,6 +42,13 @@ class GPT2Policy(Policy):
policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = col_nn.VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = col_nn.PaddingEmbedding
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
@ -73,10 +75,6 @@ class GPT2Policy(Policy):
if self.shard_config.enable_tensor_parallelism:
policy[GPT2Model] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wte",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="drop",
target_module=col_nn.DropoutForParallelInput,
@ -137,6 +135,17 @@ class GPT2Policy(Policy):
),
],
)
if embedding_cls is not None:
# pad the vocabulary size (e.g. when using pipeline parallelism) so that it is divisible by shard_config.make_vocab_size_divisible_by
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="wte",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=GPT2Model,
)
# optimization configuration
self.append_or_create_submodule_replacement(
@ -298,8 +307,11 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": not self.shard_config.parallel_output},
target_module=col_nn.VocabParallelLMHead1D,
kwargs={
"gather_output": False,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
},
)
],
)
@ -308,7 +320,19 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
addon_module[GPT2LMHeadModel].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
module_policy.update(addon_module)
else:
addon_module = {
GPT2LMHeadModel: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=col_nn.PaddingLMHead,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
)
]
)
}
module_policy.update(addon_module)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(
@ -353,13 +377,28 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True},
target_module=col_nn.VocabParallelLMHead1D,
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
},
)
]
)
}
module_policy.update(addon_module)
else:
addon_module = {
GPT2DoubleHeadsModel: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=col_nn.PaddingLMHead,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
)
]
)
}
module_policy.update(addon_module)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(

View File

@ -29,22 +29,21 @@ class GPTJPolicy(Policy):
pass
def preprocess(self):
# reshape the embedding layer
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
self.tie_weight = self.tie_weight_check()
return self.model
def module_policy(self):
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel
policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = col_nn.VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = col_nn.PaddingEmbedding
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
@ -54,10 +53,6 @@ class GPTJPolicy(Policy):
if self.shard_config.enable_tensor_parallelism:
policy[GPTJModel] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="wte",
target_module=col_nn.VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="drop",
target_module=col_nn.DropoutForParallelInput,
@ -126,6 +121,17 @@ class GPTJPolicy(Policy):
],
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="wte",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=GPTJModel,
)
# optimization configuration
if self.shard_config.enable_fused_normalization:
self.append_or_create_submodule_replacement(
@ -255,13 +261,28 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True},
target_module=col_nn.VocabParallelLMHead1D,
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
},
)
]
)
}
policy.update(addon_module)
else:
addon_module = {
GPTJForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=col_nn.PaddingLMHead,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
)
]
)
}
policy.update(addon_module)
if self.pipeline_stage_manager is not None:
self.set_pipeline_forward(

View File

@ -6,7 +6,16 @@ import torch.nn as nn
from torch import Tensor
from torch.nn import Module
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D
from colossalai.shardformer.layer import (
FusedRMSNorm,
Linear1D_Col,
Linear1D_Row,
PaddingEmbedding,
PaddingLMHead,
RMSNorm,
VocabParallelEmbedding1D,
VocabParallelLMHead1D,
)
from ..modeling.llama import (
LlamaPipelineForwards,
@ -26,15 +35,7 @@ class LlamaPolicy(Policy):
pass
def preprocess(self):
if self.shard_config.enable_tensor_parallelism:
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
self.tie_weight = self.tie_weight_check()
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
@ -42,6 +43,13 @@ class LlamaPolicy(Policy):
policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = PaddingEmbedding
if self.shard_config.enable_fused_normalization:
norm_cls = FusedRMSNorm
else:
@ -167,10 +175,12 @@ class LlamaPolicy(Policy):
],
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=LlamaModel,
@ -327,8 +337,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs={"gather_output": not self.shard_config.parallel_output},
target_module=VocabParallelLMHead1D,
kwargs={
"gather_output": not self.shard_config.parallel_output,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
},
)
],
)
@ -337,7 +350,19 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
new_item[LlamaForCausalLM].method_replacement = {
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
}
policy.update(new_item)
else:
new_item = {
LlamaForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=PaddingLMHead,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
)
],
)
}
policy.update(new_item)
if self.pipeline_stage_manager:
# set None as default
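Note: every policy in this PR now defers the embedding replacement to module_policy() through an embedding_cls variable, as above: VocabParallelEmbedding1D when tensor parallelism is on, PaddingEmbedding when it is off but the input and output embeddings are tied (so the tied lm_head sees the same padded vocabulary), and no replacement otherwise. A condensed sketch of that shared branch; pick_embedding_cls is a hypothetical name and tie_weight stands for the result of tie_weight_check():

from colossalai.shardformer.layer import PaddingEmbedding, VocabParallelEmbedding1D

def pick_embedding_cls(shard_config, tie_weight: bool):
    if shard_config.enable_tensor_parallelism:
        return VocabParallelEmbedding1D   # vocab sharded across TP ranks
    if tie_weight:
        return PaddingEmbedding           # padded vocab, consistent with PaddingLMHead
    return None                           # keep the original nn.Embedding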

View File

@ -3,7 +3,15 @@ from typing import Dict, Union
import torch.nn as nn
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.layer import (
FusedRMSNorm,
Linear1D_Col,
Linear1D_Row,
PaddingEmbedding,
PaddingLMHead,
VocabParallelEmbedding1D,
VocabParallelLMHead1D,
)
from ..modeling.mistral import get_mistral_flash_attention_forward
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@ -16,15 +24,7 @@ class MistralPolicy(Policy):
pass
def preprocess(self):
if self.shard_config.enable_tensor_parallelism:
# Resize embedding
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
self.tie_weight = self.tie_weight_check()
return self.model
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
@ -32,6 +32,13 @@ class MistralPolicy(Policy):
policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = PaddingEmbedding
if self.shard_config.enable_sequence_parallelism:
self.shard_config.enable_sequence_parallelism = False
warnings.warn(
@ -80,10 +87,12 @@ class MistralPolicy(Policy):
],
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=MistralModel,
@ -146,6 +155,8 @@ class MistralForCausalLMPolicy(MistralPolicy):
from transformers import MistralForCausalLM
policy = super().module_policy()
if self.pipeline_stage_manager:
warnings.warn("Mistral doesn't support pipeline parallelism now.")
if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
@ -153,16 +164,30 @@ class MistralForCausalLMPolicy(MistralPolicy):
MistralForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
suffix="lm_head",
target_module=VocabParallelLMHead1D,
kwargs=dict(
gather_output=True,
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
),
)
]
)
}
else:
new_item = {
MistralForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="lm_head",
target_module=PaddingLMHead,
kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
)
]
)
}
if self.pipeline_stage_manager:
warnings.warn("Mistral doesn't support pipeline parallelism now.")
policy.update(new_item)
policy.update(new_item)
return policy
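Note: self.tie_weight = self.tie_weight_check() is the one line every preprocess() keeps. The check itself lives in the base Policy class and is not shown in this diff; as a rough illustration of what "tied" means here (a hypothetical stand-alone check, not the PR's implementation):

import torch.nn as nn

def looks_tied(model: nn.Module) -> bool:
    # A tied model shares one weight tensor between the input embedding and the
    # output projection, so padding one without the other would corrupt the logits.
    get_in = getattr(model, "get_input_embeddings", None)
    get_out = getattr(model, "get_output_embeddings", None)
    if get_in is None or get_out is None:
        return False
    inp, out = get_in(), get_out()
    return inp is not None and out is not None and inp.weight is out.weight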

View File

@ -5,7 +5,16 @@ from typing import Callable, Dict, List
import torch.nn as nn
from torch import Tensor, nn
from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.layer import (
FusedLayerNorm,
LayerNorm,
Linear1D_Col,
Linear1D_Row,
PaddingEmbedding,
PaddingLMHead,
VocabParallelEmbedding1D,
VocabParallelLMHead1D,
)
from .._utils import getattr_
from ..modeling.jit import get_jit_fused_dropout_add_func
@ -41,16 +50,7 @@ class OPTPolicy(Policy):
pass
def preprocess(self):
# reshape the embedding layer
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
self.tie_weight = self.tie_weight_check()
return self.model
def module_policy(self):
@ -58,6 +58,13 @@ class OPTPolicy(Policy):
policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = PaddingEmbedding
if self.shard_config.enable_fused_normalization:
norm_cls = FusedLayerNorm
else:
@ -68,14 +75,6 @@ class OPTPolicy(Policy):
warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
if self.shard_config.enable_tensor_parallelism:
policy[OPTDecoder] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
)
]
)
policy[OPTDecoderLayer] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
@ -114,6 +113,17 @@ class OPTPolicy(Policy):
],
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=OPTDecoder,
)
# optimization configuration
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
@ -253,8 +263,20 @@ class OPTForCausalLMPolicy(OPTPolicy):
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True),
target_module=VocabParallelLMHead1D,
kwargs=dict(
gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
),
),
policy=policy,
target_key=OPTForCausalLM,
)
else:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="lm_head",
target_module=PaddingLMHead,
kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
),
policy=policy,
target_key=OPTForCausalLM,

View File

@ -13,8 +13,11 @@ from colossalai.shardformer.layer import (
FusedRMSNorm,
Linear1D_Col,
Linear1D_Row,
PaddingEmbedding,
PaddingLMHead,
RMSNorm,
VocabParallelEmbedding1D,
VocabParallelLMHead1D,
)
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
@ -36,16 +39,7 @@ class T5BasePolicy(Policy):
pass
def preprocess(self):
# reshape the embedding layer
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
if self.shard_config.enable_tensor_parallelism:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
self.tie_weight = self.tie_weight_check()
return self.model
def module_policy(self):
@ -61,6 +55,13 @@ class T5BasePolicy(Policy):
policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = PaddingEmbedding
if self.shard_config.enable_fused_normalization:
norm_cls = FusedRMSNorm
else:
@ -77,10 +78,6 @@ class T5BasePolicy(Policy):
suffix="dropout",
target_module=DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=VocabParallelEmbedding1D,
),
]
)
policy[T5LayerSelfAttention] = ModulePolicyDescription(
@ -176,6 +173,17 @@ class T5BasePolicy(Policy):
]
)
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=T5Stack,
)
# optimization configuration
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
@ -370,11 +378,19 @@ class T5ModelPolicy(T5BasePolicy):
policy = super().module_policy()
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = PaddingEmbedding
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=T5Model,
@ -406,17 +422,44 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
policy = super().module_policy()
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = PaddingEmbedding
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="shared",
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=T5ForConditionalGeneration,
)
if self.shard_config.enable_tensor_parallelism:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
),
],
description=SubModuleReplacementDescription(
suffix="lm_head",
target_module=VocabParallelLMHead1D,
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
},
),
policy=policy,
target_key=T5ForConditionalGeneration,
)
else:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="lm_head",
target_module=PaddingLMHead,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=T5ForConditionalGeneration,
)
@ -467,11 +510,19 @@ class T5EncoderPolicy(T5BasePolicy):
policy = super().module_policy()
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = PaddingEmbedding
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=policy,
target_key=T5EncoderModel,

View File

@ -45,11 +45,7 @@ class WhisperPolicy(Policy):
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
self.tie_weight = self.tie_weight_check()
return self.model
def module_policy(self):
@ -63,6 +59,13 @@ class WhisperPolicy(Policy):
policy = {}
embedding_cls = None
if self.shard_config.enable_tensor_parallelism:
embedding_cls = col_nn.VocabParallelEmbedding1D
else:
if self.tie_weight:
embedding_cls = col_nn.PaddingEmbedding
if self.shard_config.enable_fused_normalization:
norm_cls = col_nn.FusedLayerNorm
else:
@ -167,13 +170,17 @@ class WhisperPolicy(Policy):
],
)
policy[WhisperDecoder] = ModulePolicyDescription(
sub_module_replacement=[
if embedding_cls is not None:
self.append_or_create_submodule_replacement(
description=[
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=col_nn.VocabParallelEmbedding1D,
target_module=embedding_cls,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
]
],
policy=policy,
target_key=WhisperDecoder,
)
# optimization configuration
@ -280,8 +287,21 @@ class WhisperPolicy(Policy):
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="proj_out",
target_module=col_nn.Linear1D_Col,
kwargs={"gather_output": True},
target_module=col_nn.VocabParallelLMHead1D,
kwargs={
"gather_output": True,
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
},
),
policy=base_policy,
target_key=WhisperForConditionalGeneration,
)
else:
self.append_or_create_submodule_replacement(
description=SubModuleReplacementDescription(
suffix="proj_out",
target_module=col_nn.PaddingLMHead,
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
),
policy=base_policy,
target_key=WhisperForConditionalGeneration,
@ -526,9 +546,6 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
# WhisperForAudioClassification
class WhisperForAudioClassificationPolicy(WhisperPolicy):
def preprocess(self):
return self.model
def module_policy(self):
from transformers import WhisperForAudioClassification

View File

@ -42,10 +42,9 @@ class ShardConfig:
sequence_parallelism_mode: str = None
enable_sequence_overlap: bool = False
parallel_output: bool = True
make_vocab_size_divisible_by: int = 64
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
extra_kwargs: Dict[str, Any] = field(default_factory=dict)
# TODO padding vocab
# make_vocab_size_divisible_by: int = 128
# pipeline_parallel_size: int
# data_parallel_size: int
# tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
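Note: make_vocab_size_divisible_by is now a real ShardConfig field (default 64) instead of the commented-out TODO. A minimal construction sketch; the field name and default come from the diff, everything else is illustrative, and enabling tensor parallelism additionally requires an initialized process group:

from colossalai.shardformer import ShardConfig

shard_config = ShardConfig(
    enable_tensor_parallelism=False,   # illustrative; TP needs torch.distributed set up
    make_vocab_size_divisible_by=64,   # the new field introduced by this PR
)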

View File

@ -10,6 +10,7 @@ from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.d_tensor.comm_spec import *
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.misc import LayoutException
from colossalai.tensor.padded_tensor.api import init_as_padded_tensor, is_padded_tensor
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator
from .sharding_spec import ShardingSpec
@ -607,8 +608,18 @@ class LayoutConverter(metaclass=SingletonMeta):
[3.],
[3.]])
"""
_, comm_action_sequence = self.layout_converting(source_layout, target_layout)
target_tensor = tensor
for comm_spec in comm_action_sequence:
tensor = comm_spec.covert_spec_to_action(tensor)
tensor.dist_layout = target_layout
return tensor
target_tensor = comm_spec.covert_spec_to_action(target_tensor)
target_tensor.dist_layout = target_layout
# restore the padding information
if is_padded_tensor(tensor) and not is_padded_tensor(target_tensor):
target_tensor = init_as_padded_tensor(
target_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim
)
return target_tensor
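Note: the apply path now threads the result through target_tensor and re-initializes the padding metadata at the end, because each comm action hands back a tensor that does not carry _padding_dim / _origin_length / _current_length. A small stand-alone illustration of that metadata loss and restoration (plain CPU tensors, arbitrary shapes):

import torch
from colossalai.tensor.padded_tensor import init_as_padded_tensor, is_padded_tensor, to_padded_tensor

t = to_padded_tensor(torch.ones(6, 4), current_length=8, padding_dim=0)
fresh = torch.empty_like(t).copy_(t)      # stand-in for a comm op's freshly created output
assert is_padded_tensor(t) and not is_padded_tensor(fresh)

# the same restoration LayoutConverter.apply now performs
fresh = init_as_padded_tensor(fresh, t._current_length, t._origin_length, t._padding_dim)
assert is_padded_tensor(fresh)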

View File

@ -0,0 +1,3 @@
from .api import init_as_padded_tensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor
__all__ = ["is_padded_tensor", "to_padded_tensor", "to_unpadded_tensor", "init_as_padded_tensor"]

View File

@ -0,0 +1,128 @@
import torch
def _hijack_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
"""
Hijack the detach and clone methods of the tensor so that the padding information (_padding_dim, _origin_length, _current_length) is carried over to the returned tensor.
Args:
ptensor (torch.Tensor): The padded tensor to be hijacked.
Returns:
torch.Tensor: The hijacked tensor.
"""
ptensor._unpad_detach = ptensor.detach
ptensor._unpad_clone = ptensor.clone
def new_detach(self):
t_ = self._unpad_detach()
t_._padding_dim = self._padding_dim
t_._origin_length = self._origin_length
t_._current_length = self._current_length
return t_
def new_clone(self, *args, **kwargs):
t_ = self._unpad_clone(*args, **kwargs)
t_._padding_dim = self._padding_dim
t_._origin_length = self._origin_length
t_._current_length = self._current_length
return t_
# bind the new methods to the tensor
ptensor.detach = new_detach.__get__(ptensor)
ptensor.clone = new_clone.__get__(ptensor)
return ptensor
def _hijack_back_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
"""
Restore the original detach and clone methods that were hijacked by _hijack_detach_and_clone.
Args:
ptensor (torch.Tensor): The tensor whose methods should be restored.
Returns:
torch.Tensor: The tensor with its original detach and clone methods.
"""
ptensor.detach = ptensor._unpad_detach
ptensor.clone = ptensor._unpad_clone
delattr(ptensor, "_unpad_detach")
delattr(ptensor, "_unpad_clone")
return ptensor
def is_padded_tensor(tensor: torch.Tensor) -> bool:
"""
Check whether the given tensor is a padded tensor.
Args:
tensor (torch.Tensor): The tensor to be checked.
Returns:
bool: Whether the given tensor is a padded tensor.
"""
return hasattr(tensor, "_padding_dim")
def to_padded_tensor(
tensor: torch.Tensor,
current_length: int,
padding_dim: int,
) -> torch.Tensor:
assert (
padding_dim < tensor.dim()
), f"Please passing a valid padding_dim. the dimension of the tensor is {tensor.dim()}"
if is_padded_tensor(tensor):
return tensor
origin_length = tensor.shape[padding_dim]
padding_num = current_length - origin_length
padding_data = torch.zeros(
*tensor.shape[:padding_dim],
padding_num,
*tensor.shape[padding_dim + 1 :],
device=tensor.device,
dtype=tensor.dtype,
)
tensor.data = torch.cat((tensor.data, padding_data), dim=padding_dim).contiguous()
tensor._padding_dim = padding_dim
tensor._origin_length = origin_length
tensor._current_length = current_length
_hijack_detach_and_clone(tensor)
return tensor
def to_unpadded_tensor(ptensor: torch.Tensor):
if not is_padded_tensor(ptensor):
return ptensor
unpad_slices = [slice(None)] * ptensor.dim()
unpad_slices[ptensor._padding_dim] = slice(None, ptensor._origin_length)
ptensor.data = ptensor.data[tuple(unpad_slices)]
delattr(ptensor, "_padding_dim")
delattr(ptensor, "_origin_length")
delattr(ptensor, "_current_length")
_hijack_back_detach_and_clone(ptensor)
return ptensor
def init_as_padded_tensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int):
if is_padded_tensor(tensor):
return tensor
tensor._padding_dim = padding_dim
tensor._origin_length = origin_length
tensor._current_length = current_length
_hijack_detach_and_clone(tensor)
return tensor
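Note: a quick single-process round trip through the API above, on a plain CPU tensor (no distributed setup required); the vocab-sized shapes are only an example:

import torch
from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor

w = torch.randn(50257, 8)                      # an unpadded, vocab-sized weight
p = to_padded_tensor(w, current_length=50304, padding_dim=0)
assert p.shape[0] == 50304 and is_padded_tensor(p)

# detach()/clone() are hijacked so the padding metadata survives both
assert is_padded_tensor(p.detach()) and is_padded_tensor(p.clone())

u = to_unpadded_tensor(p)
assert u.shape[0] == 50257 and not is_padded_tensor(u)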

View File

@ -23,7 +23,7 @@ def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1
rtol=rtol,
atol=atol,
msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \
dtype: {a.dtype} vs {b.dtype}",
dtype: {a.dtype} vs {b.dtype}",
)

View File

@ -27,6 +27,12 @@ from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
is_distributed_tensor,
)
from colossalai.tensor.padded_tensor import (
init_as_padded_tensor,
is_padded_tensor,
to_padded_tensor,
to_unpadded_tensor,
)
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import _cast_float, free_storage, is_ddp_ignored
@ -460,6 +466,11 @@ class GeminiDDP(ModelWrapper):
record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn
)
record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu()
if is_padded_tensor(tensor):
record_tensor = init_as_padded_tensor(
record_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim
)
record_tensor = to_unpadded_tensor(record_tensor)
assert tensor not in chunk_to_save_data
chunk_to_save_data[tensor] = record_tensor
@ -520,6 +531,8 @@ class GeminiDDP(ModelWrapper):
# deal with ddp ignored parameters
destination[prefix + name] = param if keep_vars else param.detach()
else:
if is_padded_tensor(p_mapping[param]):
p_mapping[param] = to_unpadded_tensor(p_mapping[param])
destination[prefix + name] = p_mapping[param]
del p_mapping
del param_to_save_data
@ -627,6 +640,7 @@ class GeminiDDP(ModelWrapper):
list, and will be reported together in
:meth:`~torch.nn.Module.load_state_dict`
"""
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
@ -647,6 +661,14 @@ class GeminiDDP(ModelWrapper):
if state_key in state_dict:
input_param = state_dict[state_key]
global_shape = dest_tensor.shape
if source_device_mesh is not None and source_sharding_spec is not None:
global_shape = get_global_shape(dest_tensor)
if is_padded_tensor(dest_tensor):
padding_dim = dest_tensor._padding_dim
input_param = to_padded_tensor(input_param, global_shape[padding_dim], padding_dim)
if source_device_mesh is not None and source_sharding_spec is not None:
input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
elif shard_fn is not None and gather_fn is not None:

View File

@ -21,12 +21,19 @@ from colossalai.tensor.d_tensor import (
distribute_tensor,
distribute_tensor_with_customization,
get_device_mesh,
get_global_shape,
get_sharding_spec,
init_as_dtensor,
init_tensor_as_customization_distributed,
is_customized_distributed_tensor,
is_distributed_tensor,
)
from colossalai.tensor.padded_tensor import (
init_as_padded_tensor,
is_padded_tensor,
to_padded_tensor,
to_unpadded_tensor,
)
from colossalai.utils import disposable, is_ddp_ignored
from .chunk import Chunk, ChunkManager
@ -106,7 +113,7 @@ class GeminiOptimizer(OptimizerWrapper):
max_norm: float = 0.0,
norm_type: float = 2.0,
tp_group: ProcessGroup = None,
optimizer_params_info=None,
params_info=None,
verbose: bool = False,
**defaults: Any,
):
@ -124,7 +131,7 @@ class GeminiOptimizer(OptimizerWrapper):
self.clipping_flag = max_norm > 0.0
self.max_norm = max_norm
self.tp_group = tp_group
self.optimizer_params_info = optimizer_params_info
self.params_info = params_info
self.tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
self.verbose = verbose
@ -459,7 +466,7 @@ class GeminiOptimizer(OptimizerWrapper):
is_customized_distributed = is_customized_distributed_tensor(param)
shard_spec = get_sharding_spec(param) if is_dtensor else None
device_mesh = get_device_mesh(param) if is_dtensor else None
global_shape = self.optimizer_params_info["id2shape"][param_id]
global_shape = self.params_info["id2shape"][param_id]
# If the chunk is kept gathered,
# the parameters are treated the same as that of those in strict DDP during training.
@ -477,6 +484,7 @@ class GeminiOptimizer(OptimizerWrapper):
else:
state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
if is_dtensor:
global_shape = get_global_shape(param)
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
state_tensor = init_as_dtensor(
state_tensor,
@ -490,8 +498,13 @@ class GeminiOptimizer(OptimizerWrapper):
state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
)
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
collected_states[state_name] = state_tensor.reshape(global_shape)
state_tensor = state_tensor.reshape(global_shape)
if is_padded_tensor(param):
state_tensor = init_as_padded_tensor(
state_tensor, param._current_length, param._origin_length, param._padding_dim
)
state_tensor = to_unpadded_tensor(state_tensor)
collected_states[state_name] = state_tensor
return collected_states
# Check whether the param with given id is managed by current process.
@ -535,6 +548,7 @@ class GeminiOptimizer(OptimizerWrapper):
if state_tensor.numel() == param.numel():
collected_states[state_name] = torch.reshape(state_tensor, param.shape)
if is_dtensor:
global_shape = get_global_shape(param)
state_tensor = state_tensor.to(param.device)
state_tensor = init_as_dtensor(
state_tensor, sharding_spec=shard_spec, device_mesh=device_mesh, global_shape=global_shape
@ -545,6 +559,11 @@ class GeminiOptimizer(OptimizerWrapper):
state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
)
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
if is_padded_tensor(param):
state_tensor = init_as_padded_tensor(
state_tensor, param._current_length, param._origin_length, param._padding_dim
)
state_tensor = to_unpadded_tensor(state_tensor)
return collected_states
@ -698,7 +717,7 @@ class GeminiOptimizer(OptimizerWrapper):
Load saved optimizer states into parameter with given id.
"""
def cast(param, state_range, value, key=None):
def cast(param, state_range, value, global_shape, origin_shape, key=None):
"""
Make a copy of the needed segment of value and cast it to device of param.
"""
@ -714,7 +733,14 @@ class GeminiOptimizer(OptimizerWrapper):
)
if is_dtensor:
value = torch.reshape(value, global_shape)
global_shape = get_global_shape(real_param)
if is_padded_tensor(real_param):
value = torch.reshape(value, origin_shape)
padding_dim = real_param._padding_dim
value = to_padded_tensor(value, global_shape[padding_dim], padding_dim)
if is_dtensor:
value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh)
elif is_customized_distributed:
value = torch.reshape(value, global_shape)
@ -737,10 +763,11 @@ class GeminiOptimizer(OptimizerWrapper):
is_customized_distributed = is_customized_distributed_tensor(real_param)
shard_spec = get_sharding_spec(real_param) if is_dtensor else None
device_mesh = get_device_mesh(real_param) if is_dtensor else None
global_shape = self.optimizer_params_info["id2shape"][param_id]
global_shape = self.params_info["id2shape"][param_id]
origin_shape = global_shape
for k, v in saved_states.items():
updated_states[k] = cast(fake_param, state_range, v, k)
updated_states[k] = cast(fake_param, state_range, v, global_shape, origin_shape, k)
del v # clean loaded states
self.optim.state[fake_param].update(updated_states)
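Note: both the GeminiDDP and GeminiOptimizer changes follow the same convention: padded tensors are stripped back to their original length before they reach a checkpoint, and incoming checkpoint tensors are re-padded to the destination's padding_dim and padded length before being distributed. A condensed sketch of that round trip using only the padded-tensor API (no Gemini machinery; shapes are illustrative):

import torch
from colossalai.tensor.padded_tensor import to_padded_tensor, to_unpadded_tensor

dest = to_padded_tensor(torch.zeros(50257, 8), current_length=50304, padding_dim=0)

# saving: unpad a copy so the checkpoint holds the original vocabulary size
saved = to_unpadded_tensor(dest.clone())
assert saved.shape[0] == 50257

# loading: pad the checkpoint tensor back up to the destination's padded length
loaded = to_padded_tensor(saved, current_length=dest.shape[dest._padding_dim], padding_dim=dest._padding_dim)
assert loaded.shape == dest.shape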

View File

@ -81,8 +81,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
optimizer.backward(loss)
optimizer.step()
for group in optimizer.param_groups:
group["lr"] = 0.1
optimizer.zero_grad()
with shared_tempdir() as tempdir:
model_ckpt_path = f"{tempdir}/model"
optimizer_ckpt_path = f"{tempdir}/optimizer"

View File

@ -21,7 +21,7 @@ def check_vocab_embedding_1d(lazy_init: bool):
dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding_copy, process_group=None)
assert dist_embedding_1d.weight.shape == torch.Size([64, 32])
assert dist_embedding_1d.num_embeddings == 64
assert dist_embedding_1d.num_embeddings == 128
assert dist_embedding_1d.embedding_dim == 32
assert embedding_copy.weight is dist_embedding_1d.weight

View File

@ -14,12 +14,14 @@ from torch.testing import assert_close
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
from colossalai.checkpoint_io.utils import gather_distributed_param
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.shardformer._utils import getattr_
from colossalai.shardformer.policies.auto_policy import Policy
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
from colossalai.tensor.padded_tensor.api import is_padded_tensor, to_unpadded_tensor
def build_model(
@ -247,11 +249,10 @@ def check_weight(
continue
if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
sharded_weight_list = [
torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group))
]
dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
sharded_weight = torch.cat(sharded_weight_list, dim=dim)
sharded_weight = gather_distributed_param(sharded_weight, keep_vars=False)
if is_padded_tensor(sharded_weight):
sharded_weight = to_unpadded_tensor(sharded_weight)
if verbose and dist.get_rank() == 0:
print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")

View File

@ -73,7 +73,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
# check weights
if test_config["precision"] == "fp32":
atol, rtol = 5e-4, 1e-3
# TODO: the precision difference in weight checking is too significant.
atol, rtol = 1e-3, 1e-3
else:
atol, rtol = 5e-3, 5e-3
if stage_manager is None or stage_manager.is_first_stage():

View File

@ -0,0 +1,46 @@
import torch
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, is_distributed_tensor, to_global
from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
from colossalai.testing import rerun_if_address_is_in_use, spawn
def check_padded_tensor(rank, world_size, port):
disable_existing_loggers()
launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
original_tensor = torch.rand(32, 64).to("cuda")
device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec)
padded_tensor = to_padded_tensor(d_tensor, current_length=64, padding_dim=0)
assert padded_tensor.dist_layout == d_tensor.dist_layout
tensor_copy = padded_tensor.clone()
assert is_padded_tensor(tensor_copy)
assert is_distributed_tensor(tensor_copy)
tensor_detached = padded_tensor.detach()
assert is_padded_tensor(tensor_detached)
assert is_distributed_tensor(tensor_detached)
unpadded_tensor = to_unpadded_tensor(padded_tensor)
assert unpadded_tensor.shape == d_tensor.shape
assert is_distributed_tensor(unpadded_tensor)
global_tensor = to_global(unpadded_tensor)
assert global_tensor.shape == original_tensor.shape
@rerun_if_address_is_in_use()
def test_padded_tensor():
world_size = 4
spawn(check_padded_tensor, world_size)
if __name__ == "__main__":
test_padded_tensor()