mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] refactor embedding resize (#5603)
* [branch rebase] rebase main to Feature/resize_embedding (#5554) * fix * [release] update version (#5411) * [hotfix] fix typo s/keywrods/keywords etc. (#5429) * [devops] fix compatibility (#5444) * [devops] fix compatibility * [hotfix] update compatibility test on pr * [devops] fix compatibility * [devops] record duration during comp test * [test] decrease test duration * fix falcon * [shardformer] fix gathering output when using tensor parallelism (#5431) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * [doc] release Open-Sora 1.0 with model weights (#5468) * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] update open-sora demo (#5479) * [doc] update open-sora demo * [doc] update open-sora demo * [doc] update open-sora demo * [example] add grok-1 inference (#5485) * [misc] add submodule * remove submodule * [example] support grok-1 tp inference * [example] add grok-1 inference script * [example] refactor code * [example] add grok-1 readme * [exmaple] add test ci * [exmaple] update readme --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> * [CI] run pre-commit (#5577) * fix * [release] update version (#5411) * [hotfix] fix typo s/keywrods/keywords etc. (#5429) * [devops] fix compatibility (#5444) * [devops] fix compatibility * [hotfix] update compatibility test on pr * [devops] fix compatibility * [devops] record duration during comp test * [test] decrease test duration * fix falcon * [shardformer] fix gathering output when using tensor parallelism (#5431) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * [doc] release Open-Sora 1.0 with model weights (#5468) * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] release Open-Sora 1.0 with model weights * [doc] update open-sora demo (#5479) * [doc] update open-sora demo * [doc] update open-sora demo * [doc] update open-sora demo * [example] add grok-1 inference (#5485) * [misc] add submodule * remove submodule * [example] support grok-1 tp inference * [example] add grok-1 inference script * [example] refactor code * [example] add grok-1 readme * [exmaple] add test ci * [exmaple] update readme * run pre-commit --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> * [rebase] rebase main to resize-embedding (#5581) * [release] grok-1 314b inference (#5490) * [release] grok-1 inference * [release] grok-1 inference * [release] grok-1 inference * [example] update Grok-1 inference (#5495) * revise grok-1 example * remove unused arg in scripts * prevent re-installing torch * update readme * revert modifying colossalai requirements * add perf * trivial * add tokenizer url * [hotfix] set return_outputs=False in examples and polish code (#5404) * fix: simplify merge_batch * fix: use return_outputs=False to eliminate extra memory 
consumption * feat: add return_outputs warning * style: remove `return_outputs=False` as it is the default value * [release] grok-1 inference benchmark (#5500) * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [release] grok-1 inference benchmark * [shardformer]Fix lm parallel. (#5480) * fix * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * fix lm forward distribution * fix * test ci * fix * [fix] fix grok-1 example typo (#5506) * [devops] fix example test ci (#5504) * Fix ColoTensorSpec for py11 (#5440) * fixed layout converter caching and updated tester * Empty-Commit * [shardformer] update colo attention to support custom mask (#5510) * [feature] refactor colo attention (#5462) * [extension] update api * [feature] add colo attention * [feature] update sdpa * [feature] update npu attention * [feature] update flash-attn * [test] add flash attn test * [test] update flash attn test * [shardformer] update modeling to fit colo attention (#5465) * [misc] refactor folder structure * [shardformer] update llama flash-attn * [shardformer] fix llama policy * [devops] update tensornvme install * [test] update llama test * [shardformer] update colo attn kernel dispatch * [shardformer] update blip2 * [shardformer] update chatglm * [shardformer] update gpt2 * [shardformer] update gptj * [shardformer] update opt * [shardformer] update vit * [shardformer] update colo attention mask prep * [shardformer] update whisper * [test] fix shardformer tests (#5514) * [test] fix shardformer tests * [test] fix shardformer tests * [format] applied code formatting on changed files in pull request 5510 (#5517) Co-authored-by: github-actions <github-actions@github.com> * [shardformer] fix pipeline forward error if custom layer distribution is used (#5189) * Use self.[distribute_layers|get_stage_index] to exploit custom layer distribution * Change static methods for t5 layer distribution to member functions * Change static methods for whisper layer distribution to member functions * Replace whisper policy usage with self one * Fix test case to use non-static layer distribution methods * fix: fix typo --------- Co-authored-by: Wenhao Chen <cwher@outlook.com> * [Fix] Grok-1 use tokenizer from the same pretrained path (#5532) * [fix] use tokenizer from the same pretrained path * trust remote code * [ColossalChat] Update RLHF V2 (#5286) * Add dpo. Fix sft, ppo, lora. Refactor all * fix and tested ppo * 2 nd round refactor * add ci tests * fix ci * fix ci * fix readme, style * fix readme style * fix style, fix benchmark * reproduce benchmark result, remove useless files * rename to ColossalChat * use new image * fix ci workflow * fix ci * use local model/tokenizer for ci tests * fix ci * fix ci * fix ci * fix ci timeout * fix rm progress bar. fix ci timeout * fix ci * fix ci typo * remove 3d plugin from ci temporary * test environment * cannot save optimizer * support chat template * fix readme * fix path * test ci locally * restore build_or_pr * fix ci data path * fix benchmark * fix ci, move ci tests to 3080, disable fast tokenizer * move ci to 85 * support flash attention 2 * add all-in-one data preparation script. 
Fix colossal-llama2-chat chat template * add hardware requirements * move ci test data * fix save_model, add unwrap * fix missing bos * fix missing bos; support grad accumulation with gemini * fix ci * fix ci * fix ci * fix llama2 chat template config * debug sft * debug sft * fix colossalai version requirement * fix ci * add sanity check to prevent NaN loss * fix requirements * add dummy data generation script * add dummy data generation script * add dummy data generation script * add dummy data generation script * update readme * update readme * update readme and ignore * fix logger bug * support parallel_output * modify data preparation logic * fix tokenization * update lr * fix inference * run pre-commit --------- Co-authored-by: Tong Li <tong.li352711588@gmail.com> * [shardformer, pipeline] add `gradient_checkpointing_ratio` and heterogenous shard policy for llama (#5508) * feat: add `GradientCheckpointConfig` and `PipelineGradientCheckpointConfig` * feat: apply `GradientCheckpointConfig` to policy and llama_forward * feat: move `distribute_layer` and `get_stage_index` to PipelineStageManager * fix: add optional args for `distribute_layer` and `get_stage_index` * fix: fix changed API calls * test: update llama tests * style: polish `GradientCheckpointConfig` * fix: fix pipeline utils tests * fix incorrect sharding without zero (#5545) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [shardformer] Sequence Parallelism Optimization (#5533) * sequence parallel optimization * validate sequence parallel in llama (code to be polished) * shardformer api writing * integrate sequence parallel in ShardFormer * fix pp bugs and sp bugs for LlaMa model * integrating ring-based sequence parallelism into ShardFormer * [sequence parallelism]: Add fused megatron function * integrating ring-based sequence parallelism into ShardFormer --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> * fix bugs when useing sp and flashattention together * fix operation function name * support flash attention for ulysses-style sp * clarify sp process group * fix compatibility bugs in moe plugin * fix fused linear bugs * fix linear layer test * support gpt model all-to-all sp * modify shard data dimension (meant to be dim=-1) * support megtron-style sp and distributed attn for llama model * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * finish sp mode 3 support for gpt * using all_to_all_single when batch size is 1 * support mode 2 sp in gpt2 (#5) * [shardformer] add megatron sp to llama * support llama7B 128k with distributed attention * [shardformer] robustness enhancement * add block attn * sp mode 1: keep input as a complete sequence * fix sp compatability * refactor ring implementation * support mode 2 sp in gpt2 * polish code * enable distributed attn mask when using sp mode 2 and 3 in llama * automatically enable flash attn when using sp mode 2 and 3 in llama * inplace attn mask * add zero2 support for sequence parallel * polish code * fix bugs * fix gemini checkpoint io * loose tensor checking atol and rtol * add comment * fix llama layernorm grad * fix zero grad * fix zero grad * fix conflict * update split and gather auto grad func * sequence parallel: inside text split (#6) * polish code (part 1) * polish code (part 2) * polish code (part 2.5) * polish code (part 3) * sequence parallel: inside text split * miscellaneous minor fixes 
* polish code * fix ulysses style ZeRO * sequence parallel: inside text split * miscellaneous minor fixes * disaggregate sp group and dp group for sp * fix llama and gpt sp * polish code * move ulysses grad sync to ddp (#9) * remove zero_stage and unbind the grad sync for alltoall sp * add 2d group creation test * move ulysses grad sync to ddp * add 2d group creation test * remove useless code * change shard config not to enable sp when enable_all_optimizations * add sp warnings for several model * remove useless code --------- Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> * [hotfix] quick fixes to make legacy tutorials runnable (#5559) Co-authored-by: Edenzzzz <wtan45@wisc.edu> * [fix] fix typo s/muiti-node /multi-node etc. (#5448) * [hotfix] fix typo s/get_defualt_parser /get_default_parser (#5548) * [devops] remove post commit ci (#5566) * [devops] remove post commit ci * [misc] run pre-commit on all files * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --------- Co-authored-by: binmakeswell <binmakeswell@gmail.com> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Wenhao Chen <cwher@outlook.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: Insu Jang <insujang@umich.edu> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> * [shardformer]enable padding vocabulary size. 
(#5489) * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix * fix fix fix * fix gather output * fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * revert * padding vocab * padding vocabe * fix * fix * fxi * test ci * fix fix fix fix * fix fix * fix * fix * Update hybrid_parallel_plugin.py fix fix fix * fix fix * fix fix * fix * resolve super init resolve super init resolve super init resolve super init * resolve comments * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * vocab checkpointio * padding vocab_size when using pipeline parallellism padding vocab_size when using pipeline parallellism fix fix * fix fix fix * fix * fix fix resize embedding fix resize embedding * fix resize embedding fix * revert * revert * padding vocab * fix * fix fix * fix fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix ci * fix * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix * cherry-pick * revert moe modify * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix fix fix fix fix fix fix fix * resolve comments resolve comments resolve comments resolve comments resolve comments * ptensor ptensor resolve comments fix fix fix fix fix resolve comments resolve comments resolve comments resolve comments resolve comments --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix rebase * fix rebase --------- Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: Wenhao Chen <cwher@outlook.com> Co-authored-by: Rocky Duan <dementrock@users.noreply.github.com> Co-authored-by: Edenzzzz <wtan45@wisc.edu> Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: Insu Jang <insujang@umich.edu> Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: Zhongkai Zhao <kanezz620@gmail.com> Co-authored-by: linsj20 <linsj20@mails.tsinghua.edu.cn> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>pull/5619/head
parent
3788fefc7a
commit
a0ad587c24
|
@ -44,10 +44,10 @@ ZERO_AXIS, DP_AXIS, TP_AXIS = 0, 1, 2
|
|||
def get_param_info(optim: Optimizer):
|
||||
# Get a backup of necessary information of parameters for future use, which includes:
|
||||
# 1. A mapping from integer param_id to param32 shape.
|
||||
|
||||
if optim is None:
|
||||
return {}
|
||||
param_info = {"id2shape": {}}
|
||||
|
||||
start_index = 0
|
||||
for group in optim.param_groups:
|
||||
for param_id, param in enumerate(group["params"], start_index):
|
||||
|
@ -527,7 +527,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
dataloader: Optional[DataLoader] = None,
|
||||
lr_scheduler: Optional[LRScheduler] = None,
|
||||
) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
|
||||
optimizer_params_info = get_param_info(optimizer)
|
||||
params_info = get_param_info(optimizer)
|
||||
if not isinstance(model, ModelWrapper):
|
||||
# convert model to sync bn
|
||||
# FIXME(ver217): gemini does not support sync bn
|
||||
|
@ -558,7 +558,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
**self.zero_optim_config,
|
||||
**self.optim_kwargs,
|
||||
tp_group=self.tp_group,
|
||||
optimizer_params_info=optimizer_params_info,
|
||||
params_info=params_info,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
|
||||
|
|
|
@ -213,12 +213,7 @@ def get_param_info(optim: Optimizer):
|
|||
|
||||
if optim is None:
|
||||
return {}
|
||||
param_info = {
|
||||
"param_groups": [],
|
||||
"param2id": {},
|
||||
"id2param": {},
|
||||
"param2shape": {},
|
||||
}
|
||||
param_info = {"param_groups": [], "param2id": {}, "id2param": {}, "param2shape": {}}
|
||||
start_index = 0
|
||||
for group in optim.param_groups:
|
||||
packed_group = {k: v for k, v in group.items() if k != "params"}
|
||||
|
@ -947,6 +942,8 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
num_model_chunks (int, optional): The number of model chunks for interleaved pipeline parallelism. Defaults to 1.
|
||||
gradient_checkpoint_config (GradientCheckpointConfig, optional): Configuration for gradient checkpointing. Defaults to None.
|
||||
enable_metadata_cache (bool, optional): Whether to enable metadata cache for pipeline parallelism. Defaults to True.
|
||||
make_vocab_size_divisible_by (int, optional): it's used when padding the vocabulary size, to make it choose an faster kenel. Default to 64.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -989,6 +986,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
num_model_chunks: int = 1,
|
||||
gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None,
|
||||
enable_metadata_cache: bool = True,
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert (
|
||||
|
@ -1095,6 +1093,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
sequence_parallelism_mode=sequence_parallelism_mode,
|
||||
enable_sequence_overlap=enable_sequence_overlap,
|
||||
parallel_output=parallel_output,
|
||||
make_vocab_size_divisible_by=make_vocab_size_divisible_by,
|
||||
gradient_checkpoint_config=gradient_checkpoint_config,
|
||||
)
|
||||
self.amp_config = dict(
|
||||
|
|
|
@ -14,6 +14,12 @@ from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
|||
|
||||
from colossalai.cluster import DistCoordinator
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
from colossalai.tensor.padded_tensor import (
|
||||
init_as_padded_tensor,
|
||||
is_padded_tensor,
|
||||
to_padded_tensor,
|
||||
to_unpadded_tensor,
|
||||
)
|
||||
from colossalai.utils import get_current_device
|
||||
|
||||
from .general_checkpoint_io import GeneralCheckpointIO
|
||||
|
@ -32,6 +38,7 @@ from .utils import (
|
|||
save_param_groups,
|
||||
save_state_dict,
|
||||
save_state_dict_shards,
|
||||
search_padding_dim,
|
||||
search_tp_partition_dim,
|
||||
sharded_optimizer_loading_epilogue,
|
||||
)
|
||||
|
@ -89,6 +96,8 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
if param is None:
|
||||
continue
|
||||
# Gather tensor pieces when using tensor parallel.
|
||||
if is_padded_tensor(param):
|
||||
param = to_unpadded_tensor(param)
|
||||
param_ = gather_distributed_param(param, keep_vars=False)
|
||||
block, block_size = state_dict_sharder.append_param(prefix + name, param_)
|
||||
if block is not None:
|
||||
|
@ -231,7 +240,6 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
# When pipeline is used, each stage produces its own shard files and index files.
|
||||
# Index files belonging to each stage are saved under a temporary folder ./tmp_index_files/
|
||||
# After all the state_dicts have been saved, the master rank integrates all the index files into one final index file and deletes the tmp folder.
|
||||
|
||||
final_index_file_path = copy.deepcopy(save_index_file)
|
||||
tmp_index_file_folder = os.path.join(checkpoint, "tmp_index_files")
|
||||
Path(tmp_index_file_folder).mkdir(parents=True, exist_ok=True)
|
||||
|
@ -251,6 +259,7 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
use_safetensors=use_safetensors,
|
||||
use_pp_format=True,
|
||||
)
|
||||
|
||||
if control_saving:
|
||||
assert (
|
||||
self.dp_rank == 0 and self.tp_rank == 0
|
||||
|
@ -867,6 +876,11 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
dist.all_gather(gather_tensor, v, group=tp_group)
|
||||
v = torch.cat(gather_tensor, dim=partition_dim)
|
||||
|
||||
padding_dim = search_padding_dim(v.shape, original_shape)
|
||||
if padding_dim is not None:
|
||||
v = init_as_padded_tensor(v, v.shape[padding_dim], original_shape[padding_dim], padding_dim)
|
||||
v = to_unpadded_tensor(v)
|
||||
|
||||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
|
||||
|
@ -899,6 +913,19 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
if isinstance(v, torch.Tensor) and k != "step":
|
||||
# Shard state along tensor parallel group.
|
||||
partition_dim = search_tp_partition_dim(current_shape, original_shape, self.tp_size)
|
||||
global_shape = current_shape
|
||||
if partition_dim is not None:
|
||||
# pad embedding params
|
||||
global_shape = (
|
||||
*current_shape[:partition_dim],
|
||||
current_shape[partition_dim] * self.tp_size,
|
||||
*current_shape[partition_dim + 1 :],
|
||||
)
|
||||
|
||||
padding_dim = search_padding_dim(global_shape, original_shape)
|
||||
if padding_dim is not None:
|
||||
v = to_padded_tensor(v, global_shape[padding_dim], padding_dim)
|
||||
|
||||
if partition_dim is not None:
|
||||
slice_size = current_shape[partition_dim]
|
||||
v = v.split(slice_size, dim=partition_dim)[self.tp_rank]
|
||||
|
|
|
@ -120,6 +120,15 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
|
|||
return partition_dim
|
||||
|
||||
|
||||
def search_padding_dim(global_shape: torch.Size, original_shape: torch.Size) -> Optional[int]:
|
||||
padding_dim = None
|
||||
for dim, length in enumerate(global_shape):
|
||||
if length > original_shape[dim]:
|
||||
padding_dim = dim
|
||||
break
|
||||
return padding_dim
|
||||
|
||||
|
||||
# ======================================
|
||||
# Helper classes and functions for saving shard file
|
||||
# ======================================
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
from ._operation import all_to_all_comm
|
||||
from .attn import AttnMaskType, ColoAttention
|
||||
from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
|
||||
from .embedding import Embedding1D, VocabParallelEmbedding1D
|
||||
from .linear import Linear1D_Col, Linear1D_Row
|
||||
from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
|
||||
from .linear import Linear1D_Col, Linear1D_Row, PaddingLMHead, VocabParallelLMHead1D
|
||||
from .loss import cross_entropy_1d
|
||||
from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
|
||||
from .parallel_module import ParallelModule
|
||||
|
@ -25,6 +25,9 @@ __all__ = [
|
|||
"FusedRMSNorm",
|
||||
"FusedLinear1D_Col",
|
||||
"ParallelModule",
|
||||
"PaddingEmbedding",
|
||||
"PaddingLMHead",
|
||||
"VocabParallelLMHead1D",
|
||||
"AttnMaskType",
|
||||
"ColoAttention",
|
||||
"all_to_all_comm",
|
||||
|
|
|
@ -21,10 +21,10 @@ from colossalai.tensor.d_tensor.api import (
|
|||
)
|
||||
|
||||
from ._operation import gather_forward_split_backward, reduce_forward
|
||||
from .parallel_module import ParallelModule
|
||||
from .parallel_module import PaddingParallelModule, ParallelModule
|
||||
from .utils import create_randomizer_with_offset
|
||||
|
||||
__all__ = ["Embedding1D", "VocabParallelEmbedding1D"]
|
||||
__all__ = ["Embedding1D", "VocabParallelEmbedding1D", "PaddingEmbedding"]
|
||||
|
||||
|
||||
class Embedding1D(ParallelModule):
|
||||
|
@ -161,7 +161,80 @@ class Embedding1D(ParallelModule):
|
|||
return output_parallel
|
||||
|
||||
|
||||
class VocabParallelEmbedding1D(ParallelModule):
|
||||
class PaddingEmbedding(PaddingParallelModule):
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int = None,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
weight: Optional[nn.Parameter] = None,
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
self.embed_args = args
|
||||
self.embed_kwargs = kwargs
|
||||
self.padding_idx = padding_idx
|
||||
if num_embeddings % make_vocab_size_divisible_by != 0:
|
||||
self.num_embeddings = (
|
||||
num_embeddings + make_vocab_size_divisible_by - (num_embeddings % make_vocab_size_divisible_by)
|
||||
)
|
||||
# create weight and bias
|
||||
if weight is None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
|
||||
super().__init__(self.num_embeddings, num_embeddings, weight)
|
||||
|
||||
if weight is None:
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self) -> None:
|
||||
init.normal_(self.weight)
|
||||
self._fill_padding_idx_with_zero()
|
||||
|
||||
def _fill_padding_idx_with_zero(self) -> None:
|
||||
if self.padding_idx is not None:
|
||||
with torch.no_grad():
|
||||
self.weight[self.padding_idx].fill_(0)
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
return F.embedding(input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> PaddingParallelModule:
|
||||
r"""
|
||||
Convert a native pytorch embedding module to a parallel module.
|
||||
"""
|
||||
LazyInitContext.materialize(module)
|
||||
# get the origin attributes
|
||||
num_embeddings = module.num_embeddings
|
||||
embedding_dim = module.embedding_dim
|
||||
padding_idx = module.padding_idx
|
||||
device = module.weight.device
|
||||
# create the parallel module
|
||||
padding_embedding = PaddingEmbedding(
|
||||
num_embeddings=num_embeddings,
|
||||
embedding_dim=embedding_dim,
|
||||
padding_idx=padding_idx,
|
||||
device=device,
|
||||
weight=module.weight,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return padding_embedding
|
||||
|
||||
|
||||
class VocabParallelEmbedding1D(PaddingParallelModule):
|
||||
r"""Embedding parallelized in the vocabulary dimension.
|
||||
|
||||
Args:
|
||||
|
@ -201,10 +274,10 @@ class VocabParallelEmbedding1D(ParallelModule):
|
|||
process_group: ProcessGroup = None,
|
||||
weight: Optional[nn.Parameter] = None,
|
||||
weight_initializer: Callable = init.normal_(),
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
self.embed_args = args
|
||||
|
@ -214,8 +287,23 @@ class VocabParallelEmbedding1D(ParallelModule):
|
|||
tensor_parallel_size = dist.get_world_size(group=process_group)
|
||||
tensor_parallel_rank = dist.get_rank(group=process_group)
|
||||
|
||||
self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
|
||||
self.num_embeddings = self.num_embeddings_per_partition
|
||||
# generate weight and bias
|
||||
if weight is None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
|
||||
# calculate new padding size
|
||||
multiple = make_vocab_size_divisible_by * tensor_parallel_size
|
||||
if num_embeddings % multiple != 0:
|
||||
self.num_embeddings = num_embeddings + multiple - (num_embeddings % multiple)
|
||||
|
||||
# resize vocabulary size
|
||||
super().__init__(self.num_embeddings, num_embeddings, weight)
|
||||
|
||||
# deal with tensor parallelism
|
||||
self.num_embeddings_per_partition = divide(self.num_embeddings, tensor_parallel_size)
|
||||
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
|
||||
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
|
||||
|
||||
|
@ -226,13 +314,6 @@ class VocabParallelEmbedding1D(ParallelModule):
|
|||
seed = torch.random.initial_seed()
|
||||
self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)
|
||||
|
||||
# parameter
|
||||
if weight is None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
self.weight = weight
|
||||
if not is_distributed_tensor(self.weight):
|
||||
sharded_weight = shard_rowwise(self.weight.data, process_group)
|
||||
sharded_tensor_to_existing_param(sharded_weight, self.weight)
|
||||
|
@ -243,7 +324,7 @@ class VocabParallelEmbedding1D(ParallelModule):
|
|||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
) -> ParallelModule:
|
||||
) -> PaddingParallelModule:
|
||||
r"""
|
||||
Convert a native pytorch embedding module to a parallel module.
|
||||
"""
|
||||
|
@ -303,11 +384,9 @@ class VocabParallelEmbedding1D(ParallelModule):
|
|||
# Mask the input.
|
||||
masked_input = input_.clone() - self.vocab_start_index
|
||||
masked_input[input_mask] = 0
|
||||
|
||||
output_parallel = F.embedding(
|
||||
masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs
|
||||
)
|
||||
|
||||
# Mask the output embedding.
|
||||
embedding_output = output_parallel.clone()
|
||||
embedding_output[input_mask, :] = 0.0
|
||||
|
|
|
@ -32,7 +32,7 @@ from ._operation import (
|
|||
reducescatter_forward_gather_backward,
|
||||
split_forward_gather_backward,
|
||||
)
|
||||
from .parallel_module import ParallelModule
|
||||
from .parallel_module import PaddingParallelModule, ParallelModule
|
||||
from .utils import create_randomizer_with_offset
|
||||
|
||||
__all__ = ["Linear1D_Col", "Linear1D_Row"]
|
||||
|
@ -84,8 +84,9 @@ class Linear1D_Col(ParallelModule):
|
|||
bias_: Optional[Parameter] = None,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
super().__init__(weight=weight, bias_=bias_, **kwargs)
|
||||
|
||||
# Keep input parameters
|
||||
self.in_features = in_features
|
||||
|
@ -118,6 +119,7 @@ class Linear1D_Col(ParallelModule):
|
|||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
self.weight = weight
|
||||
|
||||
if not is_distributed_tensor(self.weight):
|
||||
sharded_weight = shard_rowwise(self.weight.data, self.process_group)
|
||||
sharded_tensor_to_existing_param(sharded_weight, self.weight)
|
||||
|
@ -140,7 +142,7 @@ class Linear1D_Col(ParallelModule):
|
|||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
|
||||
) -> ParallelModule:
|
||||
r"""
|
||||
Convert a native PyTorch linear layer to a parallelized linear layer.
|
||||
|
@ -173,7 +175,6 @@ class Linear1D_Col(ParallelModule):
|
|||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -322,7 +323,7 @@ class Linear1D_Row(ParallelModule):
|
|||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
|
||||
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
|
||||
) -> ParallelModule:
|
||||
r"""
|
||||
Convert a native PyTorch linear layer to a parallelized linear layer.
|
||||
|
@ -356,7 +357,6 @@ class Linear1D_Row(ParallelModule):
|
|||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
@ -439,3 +439,211 @@ class Linear1D_Row(ParallelModule):
|
|||
return output
|
||||
else:
|
||||
return output, self.bias
|
||||
|
||||
|
||||
class PaddingLMHead(PaddingParallelModule):
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
|
||||
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
|
||||
):
|
||||
# Keep input parameters
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
|
||||
if out_features % make_vocab_size_divisible_by != 0:
|
||||
self.out_features = (
|
||||
out_features + make_vocab_size_divisible_by - (out_features % make_vocab_size_divisible_by)
|
||||
)
|
||||
if weight is None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
weight = Parameter(torch.empty(out_features, self.in_features, **factory_kwargs))
|
||||
else:
|
||||
weight.data = weight.data.to(device=device, dtype=dtype)
|
||||
|
||||
if bias:
|
||||
if bias_ is None:
|
||||
self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
|
||||
else:
|
||||
bias_.data = bias_.data.to(device=device, dtype=dtype)
|
||||
else:
|
||||
bias_ = None
|
||||
|
||||
# resize embeddings
|
||||
super().__init__(self.out_features, out_features, weight, bias_)
|
||||
|
||||
if weight is None:
|
||||
self.reset_parameters(weight_initializer, bias_initializer)
|
||||
|
||||
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
|
||||
fan_in, fan_out = self.in_features, self.out_features
|
||||
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
|
||||
if self.bias is not None:
|
||||
bias_initializer(self.bias, fan_in=fan_in)
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
|
||||
) -> PaddingParallelModule:
|
||||
r"""
|
||||
Convert a native PyTorch linear layer to a parallelized linear layer.
|
||||
"""
|
||||
LazyInitContext.materialize(module)
|
||||
# get the attributes
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
bias = module.bias is not None
|
||||
device = module.weight.device
|
||||
# ensure only one process group is passed
|
||||
|
||||
lm_head_linear = PaddingLMHead(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return lm_head_linear
|
||||
|
||||
def forward(self, input: Tensor) -> Tensor:
|
||||
output = F.linear(input, self.weight, self.bias)
|
||||
output = output[..., : self.old_num_embeddings]
|
||||
return output
|
||||
|
||||
|
||||
class VocabParallelLMHead1D(Linear1D_Col, PaddingParallelModule):
|
||||
r"""Linear layer with column parallelism.
|
||||
|
||||
The linear layer is defined as :math:`Y = XA + b`. A is parallelized along
|
||||
its second dimension as :math:`A = [A_1, ..., A_p]`.
|
||||
|
||||
Args:
|
||||
in_features (int): size of each input sample.
|
||||
out_features (int): size of each output sample.
|
||||
bias (bool, optional): If set to ``False``, the layer will not learn an additive bias, defaults to ``True``.
|
||||
dtype (`torch.dtype`): The dtype of parameters, defaults to None.
|
||||
device (`torch.device`): The device of parameters, defaults to None.
|
||||
process_group (`torch.distributed.ProcessGroup`): The process group to be used for weight sharding and communication, defaults to None.
|
||||
gather_output (bool, optional): If true, call all-gather on output and make Y available
|
||||
to all GPUs, otherwise, every GPU will have its output
|
||||
which is :math:`Y_i = XA_i`, defaults to False
|
||||
seq_parallel (`bool`): If set to ``True``, it will use sequence parallel, defaults to False.
|
||||
overlap (`bool`): If set to ``True``, it will overlap input all-gather with gradient computation during backward, defaults to False.
|
||||
skip_bias_add (bool): If set to ``True``, it will skip bias add for linear layer,
|
||||
which is preserved for kernel fusion, defaults to False
|
||||
weight_initializer (`typing.Callable`):
|
||||
The initializer of weight, defaults to kaiming uniform initializer.
|
||||
bias_initializer (`typing.Callable`):
|
||||
The initializer of bias, defaults to xavier uniform initializer.
|
||||
|
||||
More details about ``initializer`` please refer to
|
||||
`init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias: bool = True,
|
||||
dtype: torch.dtype = None,
|
||||
device: torch.device = None,
|
||||
process_group: ProcessGroup = None,
|
||||
weight: Optional[Parameter] = None,
|
||||
bias_: Optional[Parameter] = None,
|
||||
make_vocab_size_divisible_by: int = 64,
|
||||
**kwargs,
|
||||
):
|
||||
# create weight and bias
|
||||
if weight is None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
weight = Parameter(torch.empty(out_features, self.in_features, **factory_kwargs))
|
||||
if bias:
|
||||
if bias_ is None:
|
||||
bias_ = Parameter(torch.empty(out_features, **factory_kwargs))
|
||||
else:
|
||||
bias_ = None
|
||||
|
||||
# calculate new vocab size
|
||||
self.tensor_parallel_size = dist.get_world_size(group=process_group)
|
||||
new_out_features = out_features
|
||||
multiple = make_vocab_size_divisible_by * self.tensor_parallel_size
|
||||
if out_features % multiple != 0:
|
||||
new_out_features = out_features + multiple - (out_features % multiple)
|
||||
|
||||
super().__init__(
|
||||
in_features=in_features,
|
||||
out_features=new_out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=weight,
|
||||
bias_=bias_,
|
||||
**kwargs,
|
||||
new_num_embeddings=new_out_features,
|
||||
old_num_embeddings=out_features,
|
||||
)
|
||||
# get the length of valid embeddings
|
||||
tp_rank = dist.get_rank(process_group)
|
||||
partition_size = self.new_num_embeddings // dist.get_world_size(process_group)
|
||||
if self.old_num_embeddings >= (tp_rank + 1) * partition_size:
|
||||
self.num_valid_embeddings_local = partition_size
|
||||
elif self.old_num_embeddings >= tp_rank * partition_size:
|
||||
self.num_valid_embeddings_local = self.old_num_embeddings - tp_rank * partition_size
|
||||
else:
|
||||
self.num_valid_embeddings_local = 0
|
||||
|
||||
@staticmethod
|
||||
def from_native_module(
|
||||
module: nn.Linear, process_group: Union[ProcessGroup, List[ProcessGroup]], **kwargs
|
||||
) -> PaddingParallelModule:
|
||||
r"""
|
||||
Convert a native PyTorch linear layer to a parallelized linear layer.
|
||||
"""
|
||||
LazyInitContext.materialize(module)
|
||||
# get the attributes
|
||||
in_features = module.in_features
|
||||
out_features = module.out_features
|
||||
bias = module.bias is not None
|
||||
device = module.weight.device
|
||||
|
||||
lm_head_linear = VocabParallelLMHead1D(
|
||||
in_features=in_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
process_group=process_group,
|
||||
weight=module.weight,
|
||||
bias_=module.bias,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return lm_head_linear
|
||||
|
||||
def forward(self, input_: Tensor) -> Tuple[Tensor, Tensor]:
|
||||
# get forward output
|
||||
if self.skip_bias_add:
|
||||
output, bias = super().forward(input_)
|
||||
else:
|
||||
output = super().forward(input_)
|
||||
|
||||
# delete the padding of output
|
||||
if self.gather_output:
|
||||
output = output[..., : self.old_num_embeddings]
|
||||
else:
|
||||
output = output[..., : self.num_valid_embeddings_local]
|
||||
|
||||
# return
|
||||
if self.skip_bias_add:
|
||||
return output, bias
|
||||
return output
|
||||
|
|
|
@ -15,7 +15,14 @@ class DistCrossEntropy(Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int, process_group: ProcessGroup):
|
||||
def forward(
|
||||
ctx,
|
||||
vocab_logits: torch.Tensor,
|
||||
target: torch.Tensor,
|
||||
ignore_index: int,
|
||||
process_group: ProcessGroup,
|
||||
vocab_size: int,
|
||||
):
|
||||
r"""
|
||||
Calculate the cross entropy loss before gather, the origin loss function is as follows:
|
||||
loss = -log(exp(x[class])/sum(exp(x[i]))
|
||||
|
@ -41,15 +48,21 @@ class DistCrossEntropy(Function):
|
|||
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
|
||||
|
||||
# mask the target in the local device
|
||||
partition_vocab_size = vocab_logits.size()[-1]
|
||||
rank = dist.get_rank(group=process_group)
|
||||
world_size = dist.get_world_size(group=process_group)
|
||||
global_vocab_size = partition_vocab_size * world_size
|
||||
if vocab_size == None:
|
||||
partition_vocab_size = vocab_logits.size()[-1]
|
||||
global_vocab_size = partition_vocab_size * world_size
|
||||
else:
|
||||
global_vocab_size = vocab_size
|
||||
partition_vocab_size = global_vocab_size // world_size
|
||||
|
||||
# [down, up) => false, other device and -100 => true
|
||||
delta = (global_vocab_size + world_size - 1) // world_size
|
||||
down_threshold = rank * delta
|
||||
up_threshold = down_threshold + delta
|
||||
if up_threshold > global_vocab_size:
|
||||
up_threshold = global_vocab_size
|
||||
mask = (target < down_threshold) | (target >= up_threshold)
|
||||
masked_target = target.clone() - down_threshold
|
||||
masked_target[mask] = 0
|
||||
|
@ -57,7 +70,8 @@ class DistCrossEntropy(Function):
|
|||
# reshape the logits and target
|
||||
# reshape the vocab_logits to [bath_size * seq_len, vocab_size]
|
||||
# reshape the labels to [bath_size * seq_len]
|
||||
logits_2d = vocab_logits.view(-1, partition_vocab_size)
|
||||
self_vocab_size = vocab_logits.size()[-1]
|
||||
logits_2d = vocab_logits.view(-1, self_vocab_size)
|
||||
masked_target_1d = masked_target.view(-1)
|
||||
|
||||
# extract the x[class] and set the x[other device] to zero
|
||||
|
@ -104,10 +118,14 @@ class DistCrossEntropy(Function):
|
|||
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
|
||||
|
||||
grad_logits.mul_(grad_output.unsqueeze(dim=-1))
|
||||
return grad_logits, None, None, None
|
||||
return grad_logits, None, None, None, None
|
||||
|
||||
|
||||
def cross_entropy_1d(
|
||||
vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None
|
||||
vocab_logits: torch.Tensor,
|
||||
labels: torch.Tensor,
|
||||
ignore_index: int = -100,
|
||||
process_group: ProcessGroup = None,
|
||||
vocab_size: int = None,
|
||||
) -> torch.Tensor:
|
||||
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group)
|
||||
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size)
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
|
||||
import itertools
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Union
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -20,11 +20,15 @@ from colossalai.tensor.d_tensor import (
|
|||
is_distributed_tensor,
|
||||
sharded_tensor_to_param,
|
||||
)
|
||||
from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
|
||||
|
||||
__all__ = ["ParallelModule"]
|
||||
|
||||
|
||||
class ParallelModule(nn.Module, ABC):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None
|
||||
|
@ -54,7 +58,7 @@ class ParallelModule(nn.Module, ABC):
|
|||
"""
|
||||
for name, param in self._parameters.items():
|
||||
if param is not None:
|
||||
destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars)
|
||||
destination[prefix + name] = gather_distributed_param(param, keep_vars=keep_vars).data
|
||||
|
||||
for name, buf in self._buffers.items():
|
||||
if buf is not None and name not in self._non_persistent_buffers_set:
|
||||
|
@ -171,3 +175,187 @@ class ParallelModule(nn.Module, ABC):
|
|||
input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child
|
||||
if input_name not in self._modules and input_name not in local_state:
|
||||
unexpected_keys.append(key)
|
||||
|
||||
|
||||
class PaddingParallelModule(ParallelModule):
|
||||
def __init__(
|
||||
self,
|
||||
new_num_embeddings: int,
|
||||
old_num_embeddings: int,
|
||||
weight: Optional[nn.Parameter],
|
||||
bias_: Optional[nn.Parameter] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.new_num_embeddings = new_num_embeddings
|
||||
self.old_num_embeddings = old_num_embeddings
|
||||
self.weight = weight
|
||||
self.bias = bias_
|
||||
|
||||
if not (is_distributed_tensor(self.weight) or self.weight.shape[0] == self.new_num_embeddings):
|
||||
self.resize_embedding_weight()
|
||||
|
||||
if self.bias is not None and not (
|
||||
is_distributed_tensor(self.bias) or self.bias.shape[0] == self.new_num_embeddings
|
||||
):
|
||||
self.resize_embedding_bias()
|
||||
|
||||
@abstractmethod
|
||||
def from_native_module(
|
||||
module: nn.Module, process_group: Union[ProcessGroup, List[ProcessGroup]] = None
|
||||
) -> "PaddingParallelModule":
|
||||
"""
|
||||
Convert a native PyTorch module to a parallelized module.
|
||||
|
||||
Args:
|
||||
module (nn.Module): the module to be converted.
|
||||
process_group (ProcessGroup or list[ProcessGroup]): the process group(s) to be used for communication.
|
||||
If this is a list, the process group at the ith index of the list will correspond to the process group
|
||||
in the ith axis of the device mesh. Defaults to None, which means the global process group.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _save_to_state_dict(self, destination, prefix, keep_vars):
|
||||
r"""Saves module state to `destination` dictionary, containing a state
|
||||
of the module, but not its descendants. This is called on every
|
||||
submodule in :meth:`~torch.nn.Module.state_dict`.
|
||||
|
||||
In rare cases, subclasses can achieve class-specific behavior by
|
||||
overriding this method with custom logic.
|
||||
|
||||
Args:
|
||||
destination (dict): a dict where state will be stored
|
||||
prefix (str): the prefix for parameters and buffers used in this
|
||||
module
|
||||
"""
|
||||
for name, param in self._parameters.items():
|
||||
if param is not None:
|
||||
param = gather_distributed_param(param, keep_vars=keep_vars)
|
||||
if is_padded_tensor(param):
|
||||
param = to_unpadded_tensor(param)
|
||||
destination[prefix + name] = param.data
|
||||
|
||||
for name, buf in self._buffers.items():
|
||||
if buf is not None and name not in self._non_persistent_buffers_set:
|
||||
destination[prefix + name] = buf if keep_vars else buf.detach()
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(self.__class__, "get_extra_state", Module.get_extra_state) is not Module.get_extra_state:
|
||||
destination[extra_state_key] = self.get_extra_state()
|
||||
|
||||
def _load_from_state_dict(
|
||||
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
|
||||
):
|
||||
r"""Copies parameters and buffers from :attr:`state_dict` into only
|
||||
this module, but not its descendants. This is called on every submodule
|
||||
in :meth:`~torch.nn.Module.load_state_dict`. Metadata saved for this
|
||||
module in input :attr:`state_dict` is provided as :attr:`local_metadata`.
|
||||
For state dicts without metadata, :attr:`local_metadata` is empty.
|
||||
Subclasses can achieve class-specific backward compatible loading using
|
||||
the version number at `local_metadata.get("version", None)`.
|
||||
|
||||
.. note::
|
||||
:attr:`state_dict` is not the same object as the input
|
||||
:attr:`state_dict` to :meth:`~torch.nn.Module.load_state_dict`. So
|
||||
it can be modified.
|
||||
|
||||
Args:
|
||||
state_dict (dict): a dict containing parameters and
|
||||
persistent buffers.
|
||||
prefix (str): the prefix for parameters and buffers used in this
|
||||
module
|
||||
local_metadata (dict): a dict containing the metadata for this module.
|
||||
See
|
||||
strict (bool): whether to strictly enforce that the keys in
|
||||
:attr:`state_dict` with :attr:`prefix` match the names of
|
||||
parameters and buffers in this module
|
||||
missing_keys (list of str): if ``strict=True``, add missing keys to
|
||||
this list
|
||||
unexpected_keys (list of str): if ``strict=True``, add unexpected
|
||||
keys to this list
|
||||
error_msgs (list of str): error messages should be added to this
|
||||
list, and will be reported together in
|
||||
:meth:`~torch.nn.Module.load_state_dict`
|
||||
"""
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
|
||||
local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
|
||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||
|
||||
for name, param in local_state.items():
|
||||
key = prefix + name
|
||||
|
||||
if key in state_dict:
|
||||
input_param = state_dict[key]
|
||||
if not torch.overrides.is_tensor_like(input_param):
|
||||
error_msgs.append(
|
||||
'While copying the parameter named "{}", '
|
||||
"expected torch.Tensor or Tensor-like object from checkpoint but "
|
||||
"received {}".format(key, type(input_param))
|
||||
)
|
||||
continue
|
||||
|
||||
if is_padded_tensor(param):
|
||||
input_param = to_padded_tensor(input_param, param._current_length, param._padding_dim)
|
||||
|
||||
if is_distributed_tensor(param):
|
||||
# shard the input param
|
||||
device_mesh = get_device_mesh(param)
|
||||
sharding_spec = get_sharding_spec(param)
|
||||
sharded_tensor = distribute_tensor(input_param, device_mesh, sharding_spec)
|
||||
input_param = sharded_tensor_to_param(sharded_tensor)
|
||||
elif is_customized_distributed_tensor(param):
|
||||
input_param = distribute_tensor_with_customization(input_param, param.shard_fn, param.gather_fn)
|
||||
|
||||
# This is used to avoid copying uninitialized parameters into
|
||||
# non-lazy modules, since they dont have the hook to do the checks
|
||||
# in such case, it will error when accessing the .shape attribute.
|
||||
is_param_lazy = torch.nn.parameter.is_lazy(param)
|
||||
# Backward compatibility: loading 1-dim tensor from 0.3.* to version 0.4+
|
||||
if not is_param_lazy and len(param.shape) == 0 and len(input_param.shape) == 1:
|
||||
input_param = input_param[0]
|
||||
|
||||
if not is_param_lazy and input_param.shape != param.shape:
|
||||
# local shape should match the one in checkpoint
|
||||
error_msgs.append(
|
||||
"size mismatch for {}: copying a param with shape {} from checkpoint, "
|
||||
"the shape in current model is {}.".format(key, input_param.shape, param.shape)
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
param.copy_(input_param)
|
||||
except Exception as ex:
|
||||
error_msgs.append(
|
||||
'While copying the parameter named "{}", '
|
||||
"whose dimensions in the model are {} and "
|
||||
"whose dimensions in the checkpoint are {}, "
|
||||
"an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
|
||||
)
|
||||
elif strict:
|
||||
missing_keys.append(key)
|
||||
|
||||
extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
|
||||
if getattr(self.__class__, "set_extra_state", Module.set_extra_state) is not Module.set_extra_state:
|
||||
if extra_state_key in state_dict:
|
||||
self.set_extra_state(state_dict[extra_state_key])
|
||||
elif strict:
|
||||
missing_keys.append(extra_state_key)
|
||||
elif strict and (extra_state_key in state_dict):
|
||||
unexpected_keys.append(extra_state_key)
|
||||
|
||||
if strict:
|
||||
for key in state_dict.keys():
|
||||
if key.startswith(prefix) and key != extra_state_key:
|
||||
input_name = key[len(prefix) :]
|
||||
input_name = input_name.split(".", 1)[0] # get the name of param/buffer/child
|
||||
if input_name not in self._modules and input_name not in local_state:
|
||||
unexpected_keys.append(key)
|
||||
|
||||
def resize_embedding_weight(self):
|
||||
self.weight = to_padded_tensor(self.weight, self.new_num_embeddings, 0)
|
||||
|
||||
def resize_embedding_bias(self):
|
||||
self.bias = to_padded_tensor(self.bias, self.new_num_embeddings, 0)
|
||||
|
|
|
@ -26,7 +26,6 @@ from colossalai.shardformer.layer._operation import gather_forward_split_backwar
|
|||
from colossalai.shardformer.shard import ShardConfig
|
||||
|
||||
from ..layer import cross_entropy_1d
|
||||
from ..layer._operation import gather_forward_split_backward
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
@ -397,13 +396,11 @@ class GPT2PipelineForwards:
|
|||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
)
|
||||
else:
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
|
||||
if not shard_config.parallel_output:
|
||||
lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
@ -1301,12 +1298,12 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
|
|||
shift_logits = shift_logits.view(-1, shift_logits.size(-1))
|
||||
shift_labels = shift_labels.view(-1)
|
||||
loss = cross_entropy_1d(
|
||||
shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
|
||||
shift_logits,
|
||||
shift_labels,
|
||||
process_group=shard_config.tensor_parallel_process_group,
|
||||
vocab_size=self.lm_head.out_features,
|
||||
)
|
||||
|
||||
if not shard_config.parallel_output:
|
||||
lm_logits = gather_forward_split_backward(lm_logits, -1, shard_config.tensor_parallel_process_group)
|
||||
|
||||
if not return_dict:
|
||||
output = (lm_logits,) + transformer_outputs[1:]
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
|
|
@@ -316,7 +316,10 @@ class LlamaPipelineForwards:
                new_vocab_size = logits.shape[-1]
                shift_logits = shift_logits.view(-1, new_vocab_size)
                loss = cross_entropy_1d(
-                    shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
+                    shift_logits,
+                    shift_labels,
+                    process_group=shard_config.tensor_parallel_process_group,
+                    vocab_size=self.lm_head.out_features,
                )
            else:
                shift_logits = shift_logits.view(-1, self.config.vocab_size)

@@ -735,11 +738,13 @@ def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)

            new_vocab_size = logits.shape[-1]
            shift_logits = shift_logits.view(-1, new_vocab_size)
            loss = cross_entropy_1d(
-                shift_logits, shift_labels, process_group=shard_config.tensor_parallel_process_group
+                shift_logits,
+                shift_labels,
+                process_group=shard_config.tensor_parallel_process_group,
+                vocab_size=self.lm_head.out_features,
            )

        if not return_dict:
@@ -195,3 +195,12 @@ class Policy(ABC):
            List[Dict[int, Tensor]]: List of parameters that should be shared across stages. E.g. [{0: module.model.embed_tokens.weight, 3: module.lm_head.weight}]
        """
        return []

+    def tie_weight_check(self):
+        input_embedding = self.model.get_input_embeddings()
+        output_embedding = self.model.get_output_embeddings()
+        return (
+            input_embedding is not None
+            and output_embedding is not None
+            and id(input_embedding.weight) == id(output_embedding.weight)
+        )
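tie_weight_check relies only on the standard Hugging Face accessors, so the same check can be reproduced outside a policy. A small illustrative sketch (the checkpoint name is just an example):

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")  # gpt2 ties its embeddings by default

input_embedding = model.get_input_embeddings()
output_embedding = model.get_output_embeddings()

# "Tied" means both modules exist and share the very same weight tensor object.
tied = (
    input_embedding is not None
    and output_embedding is not None
    and id(input_embedding.weight) == id(output_embedding.weight)
)
print(f"embeddings tied: {tied}")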
@@ -37,17 +37,7 @@ class BertPolicy(Policy):
        pass

    def preprocess(self):
-        # reshape the embedding layer
-        r"""
-        Reshape the Embedding layer to make the embedding dimension divisible by world_size
-        """
-        # TODO:
-        if self.shard_config.enable_tensor_parallelism:
-            vocab_size = self.model.config.vocab_size
-            world_size = self.shard_config.tensor_parallel_size
-            if vocab_size % world_size != 0:
-                new_vocab_size = vocab_size + world_size - vocab_size % world_size
-                self.model.resize_token_embeddings(new_vocab_size)
+        self.tie_weight = self.tie_weight_check()
        return self.model

    def module_policy(self):

@@ -62,6 +52,13 @@ class BertPolicy(Policy):

        policy = {}

+        embedding_cls = None
+        if self.shard_config.enable_tensor_parallelism:
+            embedding_cls = col_nn.VocabParallelEmbedding1D
+        else:
+            if self.tie_weight:
+                embedding_cls = col_nn.PaddingEmbedding
+
        if self.shard_config.enable_fused_normalization:
            norm_cls = col_nn.FusedLayerNorm
        else:
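The per-policy resize that the removed preprocess code performed is plain padding arithmetic; make_vocab_size_divisible_by now drives the equivalent padding inside the new embedding and LM-head modules. A quick standalone check of the two formulas (numbers chosen for illustration; the exact rule inside the new modules may additionally account for the tensor-parallel size):

vocab_size = 30522   # e.g. BERT's vocabulary
world_size = 4       # tensor-parallel size
divisor = 64         # ShardConfig.make_vocab_size_divisible_by default

# Old behaviour: pad the vocabulary up to a multiple of the tensor-parallel size.
old_padded = vocab_size if vocab_size % world_size == 0 else vocab_size + world_size - vocab_size % world_size

# New behaviour (conceptually): pad up to a multiple of make_vocab_size_divisible_by.
new_padded = ((vocab_size + divisor - 1) // divisor) * divisor

print(old_padded, new_padded)  # 30524 30528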
@@ -150,10 +147,6 @@ class BertPolicy(Policy):

            policy[BertEmbeddings] = ModulePolicyDescription(
                sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="word_embeddings",
-                        target_module=col_nn.VocabParallelEmbedding1D,
-                    ),
                    SubModuleReplacementDescription(
                        suffix="dropout",
                        target_module=col_nn.DropoutForReplicatedInput,

@@ -168,6 +161,18 @@ class BertPolicy(Policy):
                target_key=BertModel,
            )

+        if embedding_cls is not None:
+            self.append_or_create_submodule_replacement(
+                description=[
+                    SubModuleReplacementDescription(
+                        suffix="word_embeddings",
+                        target_module=embedding_cls,
+                    )
+                ],
+                policy=policy,
+                target_key=BertEmbeddings,
+            )
+
        # optimization configuration
        # Handle bert layer
        self.append_or_create_submodule_replacement(
@@ -237,8 +242,21 @@ class BertPolicy(Policy):
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(
                    suffix="decoder",
-                    target_module=col_nn.Linear1D_Col,
-                    kwargs={"gather_output": True},
+                    target_module=col_nn.VocabParallelLMHead1D,
+                    kwargs={
+                        "gather_output": True,
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                    },
                ),
                policy=base_policy,
                target_key=BertLMPredictionHead,
            )
+        else:
+            self.append_or_create_submodule_replacement(
+                description=SubModuleReplacementDescription(
+                    suffix="decoder",
+                    target_module=col_nn.PaddingLMHead,
+                    kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
+                ),
+                policy=base_policy,
+                target_key=BertLMPredictionHead,
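The same head-selection rule recurs in every policy below (Blip, Bloom, ChatGLM, Falcon, GPT-2, GPT-J, Llama, Mistral, OPT, T5, Whisper): with tensor parallelism the LM head becomes a vocab-parallel head, otherwise a padding-only head. A compact, purely illustrative restatement of that rule (the helper function is hypothetical; the module names come from the diff):

import colossalai.shardformer.layer as col_nn

def pick_lm_head(shard_config):
    """Illustrative helper mirroring the branch used by the refactored policies."""
    if shard_config.enable_tensor_parallelism:
        # Vocab-parallel head: output is gathered and the vocabulary is padded.
        return col_nn.VocabParallelLMHead1D, {
            "gather_output": True,
            "make_vocab_size_divisible_by": shard_config.make_vocab_size_divisible_by,
        }
    # Single-device head that only pads the vocabulary.
    return col_nn.PaddingLMHead, {
        "make_vocab_size_divisible_by": shard_config.make_vocab_size_divisible_by,
    }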
@ -17,16 +17,7 @@ class BlipPolicy(Policy):
|
|||
pass
|
||||
|
||||
def preprocess(self):
|
||||
# reshape the embedding layer
|
||||
r"""
|
||||
Reshape the Embedding layer to make the embedding dimension divisible by world_size
|
||||
"""
|
||||
# TODO:
|
||||
vocab_size = self.model.config.qformer_config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
|
@ -43,6 +34,13 @@ class BlipPolicy(Policy):
|
|||
|
||||
policy = {}
|
||||
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = col_nn.VocabParallelEmbedding1D
|
||||
else:
|
||||
if self.tie_weight:
|
||||
embedding_cls = col_nn.PaddingEmbedding
|
||||
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
norm_cls = col_nn.FusedLayerNorm
|
||||
else:
|
||||
|
@ -202,22 +200,48 @@ class BlipPolicy(Policy):
|
|||
],
|
||||
)
|
||||
|
||||
policy[OPTForCausalLM] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="model.decoder.embed_tokens",
|
||||
target_module=col_nn.VocabParallelEmbedding1D,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"gather_output": True},
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
policy[Blip2Attention] = ModulePolicyDescription(method_replacement={"forward": forward_fn()})
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="model.decoder.embed_tokens",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=OPTForCausalLM,
|
||||
)
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=col_nn.VocabParallelLMHead1D,
|
||||
kwargs={
|
||||
"gather_output": True,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
},
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=OPTForCausalLM,
|
||||
)
|
||||
else:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=col_nn.PaddingLMHead,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=OPTForCausalLM,
|
||||
)
|
||||
# optimization configuration
|
||||
# Handle Blip2EncoderLayer layer
|
||||
self.append_or_create_submodule_replacement(
|
||||
|
|
|
@ -35,16 +35,7 @@ class BloomPolicy(Policy):
|
|||
pass
|
||||
|
||||
def preprocess(self):
|
||||
# reshape the embedding layer
|
||||
r"""
|
||||
Reshape the Embedding layer to make the embedding dimension divisible by world_size
|
||||
"""
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
|
@ -52,6 +43,13 @@ class BloomPolicy(Policy):
|
|||
|
||||
policy = {}
|
||||
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = col_nn.VocabParallelEmbedding1D
|
||||
else:
|
||||
if self.tie_weight:
|
||||
embedding_cls = col_nn.PaddingEmbedding
|
||||
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
norm_cls = col_nn.FusedLayerNorm
|
||||
else:
|
||||
|
@ -112,12 +110,19 @@ class BloomPolicy(Policy):
|
|||
method_replacement={
|
||||
"build_alibi_tensor": build_bloom_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group)
|
||||
},
|
||||
sub_module_replacement=[
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="word_embeddings",
|
||||
target_module=col_nn.VocabParallelEmbedding1D,
|
||||
)
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=BloomModel,
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
|
@ -282,7 +287,21 @@ class BloomForCausalLMPolicy(BloomPolicy):
|
|||
if self.shard_config.enable_tensor_parallelism:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
|
||||
suffix="lm_head",
|
||||
target_module=col_nn.VocabParallelLMHead1D,
|
||||
kwargs=dict(
|
||||
gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
|
||||
),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=BloomForCausalLM,
|
||||
)
|
||||
else:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=col_nn.PaddingLMHead,
|
||||
kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=BloomForCausalLM,
|
||||
|
|
|
@ -25,20 +25,12 @@ class ChatGLMPolicy(Policy):
|
|||
pass
|
||||
|
||||
def preprocess(self):
|
||||
# Resize embedding
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
vocab_size = self.model.config.padded_vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
# the batch_size_dim is bounded to Model
|
||||
bsz_dim = 1
|
||||
setattr(self.model, "batch_size_dim", bsz_dim)
|
||||
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
|
@ -46,6 +38,13 @@ class ChatGLMPolicy(Policy):
|
|||
|
||||
policy = {}
|
||||
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = col_nn.VocabParallelEmbedding1D
|
||||
else:
|
||||
if self.tie_weight:
|
||||
embedding_cls = col_nn.PaddingEmbedding
|
||||
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
if self.model.config.rmsnorm:
|
||||
norm_cls = col_nn.FusedRMSNorm
|
||||
|
@ -68,16 +67,6 @@ class ChatGLMPolicy(Policy):
|
|||
sp_partial_derived = sp_mode == "split_gather"
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
policy[ChatGLMModel] = ModulePolicyDescription(
|
||||
attribute_replacement={},
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embedding.word_embeddings",
|
||||
target_module=col_nn.VocabParallelEmbedding1D,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
policy[GLMBlock] = ModulePolicyDescription(
|
||||
attribute_replacement={
|
||||
"self_attention.num_attention_heads_per_partition": self.model.config.num_attention_heads
|
||||
|
@ -114,6 +103,19 @@ class ChatGLMPolicy(Policy):
|
|||
),
|
||||
],
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embedding.word_embeddings",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=ChatGLMModel,
|
||||
)
|
||||
# optimization configuration
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
|
|
|
@ -32,16 +32,7 @@ class FalconPolicy(Policy):
|
|||
pass
|
||||
|
||||
def preprocess(self):
|
||||
# reshape the embedding layer
|
||||
r"""
|
||||
Reshape the Embedding layer to make the embedding dimension divisible by world_size
|
||||
"""
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
|
@ -58,6 +49,14 @@ class FalconPolicy(Policy):
|
|||
warnings.warn("Falcon doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
|
||||
|
||||
policy = {}
|
||||
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = col_nn.VocabParallelEmbedding1D
|
||||
else:
|
||||
if self.tie_weight:
|
||||
embedding_cls = col_nn.PaddingEmbedding
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
attn_attribute_replacement = {
|
||||
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
|
||||
|
@ -98,12 +97,19 @@ class FalconPolicy(Policy):
|
|||
method_replacement={
|
||||
"build_alibi_tensor": build_falcon_alibi_tensor_fn(self.shard_config.tensor_parallel_process_group)
|
||||
},
|
||||
sub_module_replacement=[
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="word_embeddings",
|
||||
target_module=col_nn.VocabParallelEmbedding1D,
|
||||
)
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
],
|
||||
policy=policy,
|
||||
target_key=FalconModel,
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
|
@ -232,11 +238,26 @@ class FalconForCausalLMPolicy(FalconPolicy):
|
|||
if self.shard_config.enable_tensor_parallelism:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)
|
||||
suffix="lm_head",
|
||||
target_module=col_nn.VocabParallelLMHead1D,
|
||||
kwargs=dict(
|
||||
gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
|
||||
),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=FalconForCausalLM,
|
||||
)
|
||||
else:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=col_nn.PaddingLMHead,
|
||||
kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=FalconForCausalLM,
|
||||
)
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
self.set_pipeline_forward(
|
||||
model_cls=FalconForCausalLM,
|
||||
|
|
|
@ -34,12 +34,7 @@ class GPT2Policy(Policy):
|
|||
r"""
|
||||
Reshape the Embedding layer to make the embedding dimension divisible by world_size
|
||||
"""
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
|
@ -47,6 +42,13 @@ class GPT2Policy(Policy):
|
|||
|
||||
policy = {}
|
||||
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = col_nn.VocabParallelEmbedding1D
|
||||
else:
|
||||
if self.tie_weight:
|
||||
embedding_cls = col_nn.PaddingEmbedding
|
||||
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
norm_cls = col_nn.FusedLayerNorm
|
||||
else:
|
||||
|
@ -73,10 +75,6 @@ class GPT2Policy(Policy):
|
|||
if self.shard_config.enable_tensor_parallelism:
|
||||
policy[GPT2Model] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wte",
|
||||
target_module=col_nn.VocabParallelEmbedding1D,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="drop",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
|
@ -137,6 +135,17 @@ class GPT2Policy(Policy):
|
|||
),
|
||||
],
|
||||
)
|
||||
if embedding_cls is not None:
|
||||
# padding vocabulary size when using pp to make it divisible by shard_config.make_vocab_size_divisible_by
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="wte",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
policy=policy,
|
||||
target_key=GPT2Model,
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
self.append_or_create_submodule_replacement(
|
||||
|
@ -298,8 +307,11 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
|
|||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"gather_output": not self.shard_config.parallel_output},
|
||||
target_module=col_nn.VocabParallelLMHead1D,
|
||||
kwargs={
|
||||
"gather_output": False,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
@ -308,7 +320,19 @@ class GPT2LMHeadModelPolicy(GPT2Policy):
|
|||
addon_module[GPT2LMHeadModel].method_replacement = {
|
||||
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
else:
|
||||
addon_module = {
|
||||
GPT2LMHeadModel: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=col_nn.PaddingLMHead,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
self.set_pipeline_forward(
|
||||
|
@ -353,13 +377,28 @@ class GPT2DoubleHeadsModelPolicy(GPT2Policy):
|
|||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"gather_output": True},
|
||||
target_module=col_nn.VocabParallelLMHead1D,
|
||||
kwargs={
|
||||
"gather_output": True,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
else:
|
||||
addon_module = {
|
||||
GPT2DoubleHeadsModel: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=col_nn.PaddingLMHead,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
module_policy.update(addon_module)
|
||||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
self.set_pipeline_forward(
|
||||
|
|
|
@ -29,22 +29,21 @@ class GPTJPolicy(Policy):
|
|||
pass
|
||||
|
||||
def preprocess(self):
|
||||
# reshape the embedding layer
|
||||
r"""
|
||||
Reshape the Embedding layer to make the embedding dimension divisible by world_size
|
||||
"""
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJBlock, GPTJModel
|
||||
|
||||
policy = {}
|
||||
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = col_nn.VocabParallelEmbedding1D
|
||||
else:
|
||||
if self.tie_weight:
|
||||
embedding_cls = col_nn.PaddingEmbedding
|
||||
|
||||
if self.shard_config.enable_sequence_parallelism:
|
||||
self.shard_config.enable_sequence_parallelism = False
|
||||
warnings.warn("GPTJ doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
|
||||
|
@ -54,10 +53,6 @@ class GPTJPolicy(Policy):
|
|||
if self.shard_config.enable_tensor_parallelism:
|
||||
policy[GPTJModel] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="wte",
|
||||
target_module=col_nn.VocabParallelEmbedding1D,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="drop",
|
||||
target_module=col_nn.DropoutForParallelInput,
|
||||
|
@ -126,6 +121,17 @@ class GPTJPolicy(Policy):
|
|||
],
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="wte",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
policy=policy,
|
||||
target_key=GPTJModel,
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
self.append_or_create_submodule_replacement(
|
||||
|
@ -255,13 +261,28 @@ class GPTJForCausalLMPolicy(GPTJPolicy):
|
|||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"gather_output": True},
|
||||
target_module=col_nn.VocabParallelLMHead1D,
|
||||
kwargs={
|
||||
"gather_output": True,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
},
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(addon_module)
|
||||
else:
|
||||
addon_module = {
|
||||
GPTJForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=col_nn.PaddingLMHead,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
policy.update(addon_module)
|
||||
|
||||
if self.pipeline_stage_manager is not None:
|
||||
self.set_pipeline_forward(
|
||||
|
|
|
@ -6,7 +6,16 @@ import torch.nn as nn
|
|||
from torch import Tensor
|
||||
from torch.nn import Module
|
||||
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, RMSNorm, VocabParallelEmbedding1D
|
||||
from colossalai.shardformer.layer import (
|
||||
FusedRMSNorm,
|
||||
Linear1D_Col,
|
||||
Linear1D_Row,
|
||||
PaddingEmbedding,
|
||||
PaddingLMHead,
|
||||
RMSNorm,
|
||||
VocabParallelEmbedding1D,
|
||||
VocabParallelLMHead1D,
|
||||
)
|
||||
|
||||
from ..modeling.llama import (
|
||||
LlamaPipelineForwards,
|
||||
|
@ -26,15 +35,7 @@ class LlamaPolicy(Policy):
|
|||
pass
|
||||
|
||||
def preprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# Resize embedding
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
|
@ -42,6 +43,13 @@ class LlamaPolicy(Policy):
|
|||
|
||||
policy = {}
|
||||
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = VocabParallelEmbedding1D
|
||||
else:
|
||||
if self.tie_weight:
|
||||
embedding_cls = PaddingEmbedding
|
||||
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
norm_cls = FusedRMSNorm
|
||||
else:
|
||||
|
@ -167,10 +175,12 @@ class LlamaPolicy(Policy):
|
|||
],
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
policy=policy,
|
||||
target_key=LlamaModel,
|
||||
|
@ -327,8 +337,11 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
|||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs={"gather_output": not self.shard_config.parallel_output},
|
||||
target_module=VocabParallelLMHead1D,
|
||||
kwargs={
|
||||
"gather_output": not self.shard_config.parallel_output,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
@ -337,7 +350,19 @@ class LlamaForCausalLMPolicy(LlamaPolicy):
|
|||
new_item[LlamaForCausalLM].method_replacement = {
|
||||
"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)
|
||||
}
|
||||
policy.update(new_item)
|
||||
else:
|
||||
new_item = {
|
||||
LlamaForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=PaddingLMHead,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
)
|
||||
],
|
||||
)
|
||||
}
|
||||
policy.update(new_item)
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
# set None as default
|
||||
|
|
|
@ -3,7 +3,15 @@ from typing import Dict, Union
|
|||
|
||||
import torch.nn as nn
|
||||
|
||||
from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
||||
from colossalai.shardformer.layer import (
|
||||
FusedRMSNorm,
|
||||
Linear1D_Col,
|
||||
Linear1D_Row,
|
||||
PaddingEmbedding,
|
||||
PaddingLMHead,
|
||||
VocabParallelEmbedding1D,
|
||||
VocabParallelLMHead1D,
|
||||
)
|
||||
|
||||
from ..modeling.mistral import get_mistral_flash_attention_forward
|
||||
from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
|
||||
|
@ -16,15 +24,7 @@ class MistralPolicy(Policy):
|
|||
pass
|
||||
|
||||
def preprocess(self):
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# Resize embedding
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
return self.model
|
||||
|
||||
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
|
||||
|
@ -32,6 +32,13 @@ class MistralPolicy(Policy):
|
|||
|
||||
policy = {}
|
||||
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = VocabParallelEmbedding1D
|
||||
else:
|
||||
if self.tie_weight:
|
||||
embedding_cls = PaddingEmbedding
|
||||
|
||||
if self.shard_config.enable_sequence_parallelism:
|
||||
self.shard_config.enable_sequence_parallelism = False
|
||||
warnings.warn(
|
||||
|
@ -80,10 +87,12 @@ class MistralPolicy(Policy):
|
|||
],
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
policy=policy,
|
||||
target_key=MistralModel,
|
||||
|
@ -146,6 +155,8 @@ class MistralForCausalLMPolicy(MistralPolicy):
|
|||
from transformers import MistralForCausalLM
|
||||
|
||||
policy = super().module_policy()
|
||||
if self.pipeline_stage_manager:
|
||||
warnings.warn("Mistral doesn't support pipeline parallelism now.")
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
# add a new item for casual lm
|
||||
|
@ -153,16 +164,30 @@ class MistralForCausalLMPolicy(MistralPolicy):
|
|||
MistralForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
|
||||
suffix="lm_head",
|
||||
target_module=VocabParallelLMHead1D,
|
||||
kwargs=dict(
|
||||
gather_output=True,
|
||||
make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by,
|
||||
),
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
else:
|
||||
new_item = {
|
||||
MistralForCausalLM: ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=PaddingLMHead,
|
||||
kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
|
||||
)
|
||||
]
|
||||
)
|
||||
}
|
||||
|
||||
if self.pipeline_stage_manager:
|
||||
warnings.warn("Mistral doesn't support pipeline parallelism now.")
|
||||
|
||||
policy.update(new_item)
|
||||
policy.update(new_item)
|
||||
|
||||
return policy
|
||||
|
||||
|
|
|
@ -5,7 +5,16 @@ from typing import Callable, Dict, List
|
|||
import torch.nn as nn
|
||||
from torch import Tensor, nn
|
||||
|
||||
from colossalai.shardformer.layer import FusedLayerNorm, LayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
|
||||
from colossalai.shardformer.layer import (
|
||||
FusedLayerNorm,
|
||||
LayerNorm,
|
||||
Linear1D_Col,
|
||||
Linear1D_Row,
|
||||
PaddingEmbedding,
|
||||
PaddingLMHead,
|
||||
VocabParallelEmbedding1D,
|
||||
VocabParallelLMHead1D,
|
||||
)
|
||||
|
||||
from .._utils import getattr_
|
||||
from ..modeling.jit import get_jit_fused_dropout_add_func
|
||||
|
@ -41,16 +50,7 @@ class OPTPolicy(Policy):
|
|||
pass
|
||||
|
||||
def preprocess(self):
|
||||
# reshape the embedding layer
|
||||
r"""
|
||||
Reshape the Embedding layer to make the embedding dimension divisible by world_size
|
||||
"""
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
|
@ -58,6 +58,13 @@ class OPTPolicy(Policy):
|
|||
|
||||
policy = {}
|
||||
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = VocabParallelEmbedding1D
|
||||
else:
|
||||
if self.tie_weight:
|
||||
embedding_cls = PaddingEmbedding
|
||||
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
norm_cls = FusedLayerNorm
|
||||
else:
|
||||
|
@ -68,14 +75,6 @@ class OPTPolicy(Policy):
|
|||
warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
policy[OPTDecoder] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
)
|
||||
]
|
||||
)
|
||||
policy[OPTDecoderLayer] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
SubModuleReplacementDescription(
|
||||
|
@ -114,6 +113,17 @@ class OPTPolicy(Policy):
|
|||
],
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
policy=policy,
|
||||
target_key=OPTDecoder,
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
|
@ -253,8 +263,20 @@ class OPTForCausalLMPolicy(OPTPolicy):
|
|||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=Linear1D_Col,
|
||||
kwargs=dict(gather_output=True),
|
||||
target_module=VocabParallelLMHead1D,
|
||||
kwargs=dict(
|
||||
gather_output=True, make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by
|
||||
),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=OPTForCausalLM,
|
||||
)
|
||||
else:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=PaddingLMHead,
|
||||
kwargs=dict(make_vocab_size_divisible_by=self.shard_config.make_vocab_size_divisible_by),
|
||||
),
|
||||
policy=policy,
|
||||
target_key=OPTForCausalLM,
|
||||
|
|
|
@ -13,8 +13,11 @@ from colossalai.shardformer.layer import (
|
|||
FusedRMSNorm,
|
||||
Linear1D_Col,
|
||||
Linear1D_Row,
|
||||
PaddingEmbedding,
|
||||
PaddingLMHead,
|
||||
RMSNorm,
|
||||
VocabParallelEmbedding1D,
|
||||
VocabParallelLMHead1D,
|
||||
)
|
||||
from colossalai.shardformer.policies.base_policy import ModulePolicyDescription
|
||||
|
||||
|
@ -36,16 +39,7 @@ class T5BasePolicy(Policy):
|
|||
pass
|
||||
|
||||
def preprocess(self):
|
||||
# reshape the embedding layer
|
||||
r"""
|
||||
Reshape the Embedding layer to make the embedding dimension divisible by world_size
|
||||
"""
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
|
@ -61,6 +55,13 @@ class T5BasePolicy(Policy):
|
|||
|
||||
policy = {}
|
||||
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = VocabParallelEmbedding1D
|
||||
else:
|
||||
if self.tie_weight:
|
||||
embedding_cls = PaddingEmbedding
|
||||
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
norm_cls = FusedRMSNorm
|
||||
else:
|
||||
|
@ -77,10 +78,6 @@ class T5BasePolicy(Policy):
|
|||
suffix="dropout",
|
||||
target_module=DropoutForParallelInput,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
),
|
||||
]
|
||||
)
|
||||
policy[T5LayerSelfAttention] = ModulePolicyDescription(
|
||||
|
@ -176,6 +173,17 @@ class T5BasePolicy(Policy):
|
|||
]
|
||||
)
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
policy=policy,
|
||||
target_key=T5Stack,
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
|
@ -370,11 +378,19 @@ class T5ModelPolicy(T5BasePolicy):
|
|||
|
||||
policy = super().module_policy()
|
||||
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = VocabParallelEmbedding1D
|
||||
else:
|
||||
if self.tie_weight:
|
||||
embedding_cls = PaddingEmbedding
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
policy=policy,
|
||||
target_key=T5Model,
|
||||
|
@ -406,17 +422,44 @@ class T5ForConditionalGenerationPolicy(T5BasePolicy):
|
|||
|
||||
policy = super().module_policy()
|
||||
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = VocabParallelEmbedding1D
|
||||
else:
|
||||
if self.tie_weight:
|
||||
embedding_cls = PaddingEmbedding
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
policy=policy,
|
||||
target_key=T5ForConditionalGeneration,
|
||||
)
|
||||
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
),
|
||||
SubModuleReplacementDescription(
|
||||
suffix="lm_head", target_module=Linear1D_Col, kwargs=dict(gather_output=True)
|
||||
),
|
||||
],
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=VocabParallelLMHead1D,
|
||||
kwargs={
|
||||
"gather_output": True,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
},
|
||||
),
|
||||
policy=policy,
|
||||
target_key=T5ForConditionalGeneration,
|
||||
)
|
||||
else:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="lm_head",
|
||||
target_module=PaddingLMHead,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
policy=policy,
|
||||
target_key=T5ForConditionalGeneration,
|
||||
)
|
||||
|
@ -467,11 +510,19 @@ class T5EncoderPolicy(T5BasePolicy):
|
|||
|
||||
policy = super().module_policy()
|
||||
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = VocabParallelEmbedding1D
|
||||
else:
|
||||
if self.tie_weight:
|
||||
embedding_cls = PaddingEmbedding
|
||||
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="shared",
|
||||
target_module=VocabParallelEmbedding1D,
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
policy=policy,
|
||||
target_key=T5EncoderModel,
|
||||
|
|
|
@ -45,11 +45,7 @@ class WhisperPolicy(Policy):
|
|||
r"""
|
||||
Reshape the Embedding layer to make the embedding dimension divisible by world_size
|
||||
"""
|
||||
vocab_size = self.model.config.vocab_size
|
||||
world_size = self.shard_config.tensor_parallel_size
|
||||
if vocab_size % world_size != 0:
|
||||
new_vocab_size = vocab_size + world_size - vocab_size % world_size
|
||||
self.model.resize_token_embeddings(new_vocab_size)
|
||||
self.tie_weight = self.tie_weight_check()
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
|
@ -63,6 +59,13 @@ class WhisperPolicy(Policy):
|
|||
|
||||
policy = {}
|
||||
|
||||
embedding_cls = None
|
||||
if self.shard_config.enable_tensor_parallelism:
|
||||
embedding_cls = col_nn.VocabParallelEmbedding1D
|
||||
else:
|
||||
if self.tie_weight:
|
||||
embedding_cls = col_nn.PaddingEmbedding
|
||||
|
||||
if self.shard_config.enable_fused_normalization:
|
||||
norm_cls = col_nn.FusedLayerNorm
|
||||
else:
|
||||
|
@ -167,13 +170,17 @@ class WhisperPolicy(Policy):
|
|||
],
|
||||
)
|
||||
|
||||
policy[WhisperDecoder] = ModulePolicyDescription(
|
||||
sub_module_replacement=[
|
||||
if embedding_cls is not None:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=[
|
||||
SubModuleReplacementDescription(
|
||||
suffix="embed_tokens",
|
||||
target_module=col_nn.VocabParallelEmbedding1D,
|
||||
target_module=embedding_cls,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
]
|
||||
],
|
||||
policy=policy,
|
||||
target_key=WhisperDecoder,
|
||||
)
|
||||
|
||||
# optimization configuration
|
||||
|
@ -280,8 +287,21 @@ class WhisperPolicy(Policy):
|
|||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="proj_out",
|
||||
target_module=col_nn.Linear1D_Col,
|
||||
kwargs={"gather_output": True},
|
||||
target_module=col_nn.VocabParallelLMHead1D,
|
||||
kwargs={
|
||||
"gather_output": True,
|
||||
"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
|
||||
},
|
||||
),
|
||||
policy=base_policy,
|
||||
target_key=WhisperForConditionalGeneration,
|
||||
)
|
||||
else:
|
||||
self.append_or_create_submodule_replacement(
|
||||
description=SubModuleReplacementDescription(
|
||||
suffix="proj_out",
|
||||
target_module=col_nn.PaddingLMHead,
|
||||
kwargs={"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by},
|
||||
),
|
||||
policy=base_policy,
|
||||
target_key=WhisperForConditionalGeneration,
|
||||
|
@ -526,9 +546,6 @@ class WhisperForConditionalGenerationPolicy(WhisperPolicy):
|
|||
|
||||
# WhisperForAudioClassification
|
||||
class WhisperForAudioClassificationPolicy(WhisperPolicy):
|
||||
def preprocess(self):
|
||||
return self.model
|
||||
|
||||
def module_policy(self):
|
||||
from transformers import WhisperForAudioClassification
|
||||
|
||||
|
|
|
@@ -42,10 +42,9 @@ class ShardConfig:
    sequence_parallelism_mode: str = None
    enable_sequence_overlap: bool = False
    parallel_output: bool = True
+    make_vocab_size_divisible_by: int = 64
    gradient_checkpoint_config: Optional[GradientCheckpointConfig] = None
    extra_kwargs: Dict[str, Any] = field(default_factory=dict)
-    # TODO padding vocab
-    # make_vocab_size_divisible_by: int = 128
    # pipeline_parallel_size: int
    # data_parallel_size: int
    # tensor_parallel_mode: Literal['1d', '2d', '2.5d', '3d']
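A minimal sketch of how the new field would be set when building a shard configuration (only fields visible in this diff are used; everything else is left at its default):

from colossalai.shardformer.shard import ShardConfig

shard_config = ShardConfig(
    enable_tensor_parallelism=True,
    make_vocab_size_divisible_by=64,  # padding granularity for embedding / LM-head vocabularies
)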
@@ -10,6 +10,7 @@ from colossalai.context.singleton_meta import SingletonMeta
from colossalai.tensor.d_tensor.comm_spec import *
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.misc import LayoutException
+from colossalai.tensor.padded_tensor.api import init_as_padded_tensor, is_padded_tensor
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator

from .sharding_spec import ShardingSpec

@@ -607,8 +608,18 @@ class LayoutConverter(metaclass=SingletonMeta):
                [3.],
                [3.]])
        """

        _, comm_action_sequence = self.layout_converting(source_layout, target_layout)

+        target_tensor = tensor
        for comm_spec in comm_action_sequence:
-            tensor = comm_spec.covert_spec_to_action(tensor)
-        tensor.dist_layout = target_layout
-        return tensor
+            target_tensor = comm_spec.covert_spec_to_action(target_tensor)
+        target_tensor.dist_layout = target_layout
+
+        # restore the padding information
+        if is_padded_tensor(tensor) and not is_padded_tensor(target_tensor):
+            target_tensor = init_as_padded_tensor(
+                target_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim
+            )
+
+        return target_tensor
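The explicit restore step is needed because the padding metadata lives in ad-hoc Python attributes, and those do not follow a tensor through operations that create a new object. A tiny illustration:

import torch

t = torch.randn(4, 8)
t._padding_dim = 0  # ad-hoc attribute, as the padded-tensor API uses

u = t + 0           # any op that returns a fresh tensor
print(hasattr(u, "_padding_dim"))  # False: the metadata has to be copied back by hand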
@@ -0,0 +1,3 @@
from .api import init_as_padded_tensor, is_padded_tensor, to_padded_tensor, to_unpadded_tensor

__all__ = ["is_padded_tensor", "to_padded_tensor", "to_unpadded_tensor", "init_as_padded_tensor"]
@@ -0,0 +1,128 @@
import torch


def _hijack_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
    """
    Hijack the ``detach`` and ``clone`` methods of the tensor so that the padding metadata is carried over.

    Args:
        ptensor (torch.Tensor): The tensor to be hijacked.

    Returns:
        torch.Tensor: The hijacked tensor.
    """
    ptensor._unpad_detach = ptensor.detach
    ptensor._unpad_clone = ptensor.clone

    def new_detach(self):
        t_ = self._unpad_detach()
        t_._padding_dim = self._padding_dim
        t_._origin_length = self._origin_length
        t_._current_length = self._current_length
        return t_

    def new_clone(self, *args, **kwargs):
        t_ = self._unpad_clone(*args, **kwargs)
        t_._padding_dim = self._padding_dim
        t_._origin_length = self._origin_length
        t_._current_length = self._current_length
        return t_

    # bind the new methods to the tensor
    ptensor.detach = new_detach.__get__(ptensor)
    ptensor.clone = new_clone.__get__(ptensor)
    return ptensor


def _hijack_back_detach_and_clone(ptensor: torch.Tensor) -> torch.Tensor:
    """
    Restore the original ``detach`` and ``clone`` methods of the tensor.

    Args:
        ptensor (torch.Tensor): The tensor to be restored.

    Returns:
        torch.Tensor: The restored tensor.
    """
    ptensor.detach = ptensor._unpad_detach
    ptensor.clone = ptensor._unpad_clone

    delattr(ptensor, "_unpad_detach")
    delattr(ptensor, "_unpad_clone")

    return ptensor


def is_padded_tensor(tensor: torch.Tensor) -> bool:
    """
    Check whether the given tensor is a padded tensor.

    Args:
        tensor (torch.Tensor): The tensor to be checked.

    Returns:
        bool: Whether the given tensor is a padded tensor.
    """
    return hasattr(tensor, "_padding_dim")


def to_padded_tensor(
    tensor: torch.Tensor,
    current_length: int,
    padding_dim: int,
) -> torch.Tensor:
    assert (
        padding_dim < tensor.dim()
    ), f"Please pass a valid padding_dim. The dimension of the tensor is {tensor.dim()}"

    if is_padded_tensor(tensor):
        return tensor

    origin_length = tensor.shape[padding_dim]
    padding_num = current_length - origin_length
    padding_data = torch.zeros(
        *tensor.shape[:padding_dim],
        padding_num,
        *tensor.shape[padding_dim + 1 :],
        device=tensor.device,
        dtype=tensor.dtype,
    )
    tensor.data = torch.cat((tensor.data, padding_data), dim=padding_dim).contiguous()

    tensor._padding_dim = padding_dim
    tensor._origin_length = origin_length
    tensor._current_length = current_length

    _hijack_detach_and_clone(tensor)

    return tensor


def to_unpadded_tensor(ptensor: torch.Tensor):
    if not is_padded_tensor(ptensor):
        return ptensor

    unpad_slices = [slice(None)] * ptensor.dim()
    unpad_slices[ptensor._padding_dim] = slice(None, ptensor._origin_length)
    ptensor.data = ptensor.data[tuple(unpad_slices)]

    delattr(ptensor, "_padding_dim")
    delattr(ptensor, "_origin_length")
    delattr(ptensor, "_current_length")

    _hijack_back_detach_and_clone(ptensor)

    return ptensor


def init_as_padded_tensor(tensor: torch.Tensor, current_length: int, origin_length: int, padding_dim: int):
    if is_padded_tensor(tensor):
        return tensor

    tensor._padding_dim = padding_dim
    tensor._origin_length = origin_length
    tensor._current_length = current_length

    _hijack_detach_and_clone(tensor)

    return tensor
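A quick sanity check of the new padded-tensor helpers (shapes chosen purely for illustration):

import torch

from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor

weight = torch.randn(100, 32)              # a 100-row "vocabulary" tensor
padded = to_padded_tensor(weight, 128, 0)  # pad dim 0 up to 128 rows

assert is_padded_tensor(padded) and padded.shape == (128, 32)
assert padded.detach()._origin_length == 100  # metadata survives detach()/clone()

restored = to_unpadded_tensor(padded)
assert restored.shape == (100, 32) and not is_padded_tensor(restored)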
@ -23,7 +23,7 @@ def assert_close_loose(a: Tensor, b: Tensor, rtol: float = 1e-3, atol: float = 1
|
|||
rtol=rtol,
|
||||
atol=atol,
|
||||
msg=f"Tensor not close, shape: {a.shape} vs {b.shape}, \
|
||||
dtype: {a.dtype} vs {b.dtype}",
|
||||
dtype: {a.dtype} vs {b.dtype}",
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -27,6 +27,12 @@ from colossalai.tensor.d_tensor import (
|
|||
is_customized_distributed_tensor,
|
||||
is_distributed_tensor,
|
||||
)
|
||||
from colossalai.tensor.padded_tensor import (
|
||||
init_as_padded_tensor,
|
||||
is_padded_tensor,
|
||||
to_padded_tensor,
|
||||
to_unpadded_tensor,
|
||||
)
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.utils import _cast_float, free_storage, is_ddp_ignored
|
||||
|
||||
|
@ -460,6 +466,11 @@ class GeminiDDP(ModelWrapper):
|
|||
record_tensor, shard_fn=tensor.shard_fn, gather_fn=tensor.gather_fn
|
||||
)
|
||||
record_tensor = gather_distributed_param(record_tensor, keep_vars=False).cpu()
|
||||
if is_padded_tensor(tensor):
|
||||
record_tensor = init_as_padded_tensor(
|
||||
record_tensor, tensor._current_length, tensor._origin_length, tensor._padding_dim
|
||||
)
|
||||
record_tensor = to_unpadded_tensor(record_tensor)
|
||||
|
||||
assert tensor not in chunk_to_save_data
|
||||
chunk_to_save_data[tensor] = record_tensor
|
||||
|
@ -520,6 +531,8 @@ class GeminiDDP(ModelWrapper):
|
|||
# deal with ddp ignored parameters
|
||||
destination[prefix + name] = param if keep_vars else param.detach()
|
||||
else:
|
||||
if is_padded_tensor(p_mapping[param]):
|
||||
p_mapping[param] = to_unpadded_tensor(p_mapping[param])
|
||||
destination[prefix + name] = p_mapping[param]
|
||||
del p_mapping
|
||||
del param_to_save_data
|
||||
|
@ -627,6 +640,7 @@ class GeminiDDP(ModelWrapper):
|
|||
list, and will be reported together in
|
||||
:meth:`~torch.nn.Module.load_state_dict`
|
||||
"""
|
||||
|
||||
for hook in self._load_state_dict_pre_hooks.values():
|
||||
hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
|
@ -647,6 +661,14 @@ class GeminiDDP(ModelWrapper):
|
|||
if state_key in state_dict:
|
||||
input_param = state_dict[state_key]
|
||||
|
||||
global_shape = dest_tensor.shape
|
||||
if source_device_mesh is not None and source_sharding_spec is not None:
|
||||
global_shape = get_global_shape(dest_tensor)
|
||||
|
||||
if is_padded_tensor(dest_tensor):
|
||||
padding_dim = dest_tensor._padding_dim
|
||||
input_param = to_padded_tensor(input_param, global_shape[padding_dim], padding_dim)
|
||||
|
||||
if source_device_mesh is not None and source_sharding_spec is not None:
|
||||
input_param = distribute_tensor(input_param, source_device_mesh, source_sharding_spec)
|
||||
elif shard_fn is not None and gather_fn is not None:
|
||||
|
|
|
@ -21,12 +21,19 @@ from colossalai.tensor.d_tensor import (
|
|||
distribute_tensor,
|
||||
distribute_tensor_with_customization,
|
||||
get_device_mesh,
|
||||
get_global_shape,
|
||||
get_sharding_spec,
|
||||
init_as_dtensor,
|
||||
init_tensor_as_customization_distributed,
|
||||
is_customized_distributed_tensor,
|
||||
is_distributed_tensor,
|
||||
)
|
||||
from colossalai.tensor.padded_tensor import (
|
||||
init_as_padded_tensor,
|
||||
is_padded_tensor,
|
||||
to_padded_tensor,
|
||||
to_unpadded_tensor,
|
||||
)
|
||||
from colossalai.utils import disposable, is_ddp_ignored
|
||||
|
||||
from .chunk import Chunk, ChunkManager
|
||||
|
@ -106,7 +113,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
max_norm: float = 0.0,
|
||||
norm_type: float = 2.0,
|
||||
tp_group: ProcessGroup = None,
|
||||
optimizer_params_info=None,
|
||||
params_info=None,
|
||||
verbose: bool = False,
|
||||
**defaults: Any,
|
||||
):
|
||||
|
@ -124,7 +131,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
self.clipping_flag = max_norm > 0.0
|
||||
self.max_norm = max_norm
|
||||
self.tp_group = tp_group
|
||||
self.optimizer_params_info = optimizer_params_info
|
||||
self.params_info = params_info
|
||||
self.tp_size = dist.get_world_size(tp_group) if tp_group is not None else 1
|
||||
self.tp_rank = dist.get_rank(tp_group) if tp_group is not None else 0
|
||||
self.verbose = verbose
|
||||
|
@ -459,7 +466,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
is_customized_distributed = is_customized_distributed_tensor(param)
|
||||
shard_spec = get_sharding_spec(param) if is_dtensor else None
|
||||
device_mesh = get_device_mesh(param) if is_dtensor else None
|
||||
global_shape = self.optimizer_params_info["id2shape"][param_id]
|
||||
global_shape = self.params_info["id2shape"][param_id]
|
||||
|
||||
# If the chunk is kept gathered,
|
||||
# the parameters are treated the same as that of those in strict DDP during training.
|
||||
|
@ -477,6 +484,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
else:
|
||||
state_tensor = states[state_name].detach().clone().to(torch.float32).cpu()
|
||||
if is_dtensor:
|
||||
global_shape = get_global_shape(param)
|
||||
state_tensor = torch.reshape(state_tensor, param.shape).to(param.device)
|
||||
state_tensor = init_as_dtensor(
|
||||
state_tensor,
|
||||
|
@ -490,8 +498,13 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
|
||||
)
|
||||
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
|
||||
|
||||
collected_states[state_name] = state_tensor.reshape(global_shape)
|
||||
state_tensor = state_tensor.reshape(global_shape)
|
||||
if is_padded_tensor(param):
|
||||
state_tensor = init_as_padded_tensor(
|
||||
state_tensor, param._current_length, param._origin_length, param._padding_dim
|
||||
)
|
||||
state_tensor = to_unpadded_tensor(state_tensor)
|
||||
collected_states[state_name] = state_tensor
|
||||
return collected_states
|
||||
|
||||
# Check whether the param with given id is managed by current process.
|
||||
|
@ -535,6 +548,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
if state_tensor.numel() == param.numel():
|
||||
collected_states[state_name] = torch.reshape(state_tensor, param.shape)
|
||||
if is_dtensor:
|
||||
global_shape = get_global_shape(param)
|
||||
state_tensor = state_tensor.to(param.device)
|
||||
state_tensor = init_as_dtensor(
|
||||
state_tensor, sharding_spec=shard_spec, device_mesh=device_mesh, global_shape=global_shape
|
||||
|
@ -545,6 +559,11 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
state_tensor, shard_fn=param.shard_fn, gather_fn=param.gather_fn
|
||||
)
|
||||
state_tensor = gather_distributed_param(state_tensor, keep_vars=False).cpu()
|
||||
if is_padded_tensor(param):
|
||||
state_tensor = init_as_padded_tensor(
|
||||
state_tensor, param._current_length, param._origin_length, param._padding_dim
|
||||
)
|
||||
state_tensor = to_unpadded_tensor(state_tensor)
|
||||
|
||||
return collected_states
|
||||
|
||||
|
@ -698,7 +717,7 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
Load saved optimizer states into parameter with given id.
|
||||
"""
|
||||
|
||||
def cast(param, state_range, value, key=None):
|
||||
def cast(param, state_range, value, global_shape, origin_shape, key=None):
|
||||
"""
|
||||
Make a copy of the needed segment of value and cast it to device of param.
|
||||
"""
|
||||
|
@ -714,7 +733,14 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
)
|
||||
|
||||
if is_dtensor:
|
||||
value = torch.reshape(value, global_shape)
|
||||
global_shape = get_global_shape(real_param)
|
||||
|
||||
if is_padded_tensor(real_param):
|
||||
value = torch.reshape(value, origin_shape)
|
||||
padding_dim = real_param._padding_dim
|
||||
value = to_padded_tensor(value, global_shape[padding_dim], padding_dim)
|
||||
|
||||
if is_dtensor:
|
||||
value = distribute_tensor(value, sharding_spec=shard_spec, device_mesh=device_mesh)
|
||||
elif is_customized_distributed:
|
||||
value = torch.reshape(value, global_shape)
|
||||
|
@@ -737,10 +763,11 @@ class GeminiOptimizer(OptimizerWrapper):
         is_customized_distributed = is_customized_distributed_tensor(real_param)
         shard_spec = get_sharding_spec(real_param) if is_dtensor else None
         device_mesh = get_device_mesh(real_param) if is_dtensor else None
-        global_shape = self.optimizer_params_info["id2shape"][param_id]
+        global_shape = self.params_info["id2shape"][param_id]
+        origin_shape = global_shape

         for k, v in saved_states.items():
-            updated_states[k] = cast(fake_param, state_range, v, k)
+            updated_states[k] = cast(fake_param, state_range, v, global_shape, origin_shape, k)
             del v  # clean loaded states
         self.optim.state[fake_param].update(updated_states)
@@ -81,8 +81,7 @@ def exam_state_dict(shard: bool, model_name: str, size_per_shard: int, test_conf
     optimizer.backward(loss)

     optimizer.step()
-    for group in optimizer.param_groups:
-        group["lr"] = 0.1
+    optimizer.zero_grad()
     with shared_tempdir() as tempdir:
         model_ckpt_path = f"{tempdir}/model"
         optimizer_ckpt_path = f"{tempdir}/optimizer"
@@ -21,7 +21,7 @@ def check_vocab_embedding_1d(lazy_init: bool):
     dist_embedding_1d = VocabParallelEmbedding1D.from_native_module(embedding_copy, process_group=None)

     assert dist_embedding_1d.weight.shape == torch.Size([64, 32])
-    assert dist_embedding_1d.num_embeddings == 64
+    assert dist_embedding_1d.num_embeddings == 128
     assert dist_embedding_1d.embedding_dim == 32
     assert embedding_copy.weight is dist_embedding_1d.weight
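The updated assertion reflects that num_embeddings now reports the full (global) vocabulary rather than the local shard. A hypothetical illustration of the arithmetic in plain PyTorch (not the shardformer API):

import torch.nn as nn

# Hypothetical source module mirroring the test: a 128-token vocabulary with 32-dim embeddings.
embedding = nn.Embedding(128, 32)
tp_size = 2  # assumed tensor-parallel world size in this test

local_rows = embedding.num_embeddings // tp_size  # each rank stores 64 rows of the weight
assert local_rows == 64                           # matches weight.shape == torch.Size([64, 32]) above
assert embedding.num_embeddings == 128            # matches the updated num_embeddings check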
@@ -14,12 +14,14 @@ from torch.testing import assert_close
 from colossalai.booster import Booster
 from colossalai.booster.plugin import HybridParallelPlugin
 from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
+from colossalai.checkpoint_io.utils import gather_distributed_param
 from colossalai.lazy import LazyInitContext
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.shardformer._utils import getattr_
 from colossalai.shardformer.policies.auto_policy import Policy
 from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
+from colossalai.tensor.padded_tensor.api import is_padded_tensor, to_unpadded_tensor


 def build_model(
@@ -247,11 +249,10 @@ def check_weight(
             continue

         if is_distributed_tensor(sharded_weight) or is_customized_distributed_tensor(sharded_weight):
-            sharded_weight_list = [
-                torch.zeros_like(sharded_weight).to("cuda") for _ in range(dist.get_world_size(tp_group))
-            ]
-            dist.all_gather(sharded_weight_list, sharded_weight, tp_group)
-            sharded_weight = torch.cat(sharded_weight_list, dim=dim)
+            sharded_weight = gather_distributed_param(sharded_weight, keep_vars=False)
+
+        if is_padded_tensor(sharded_weight):
+            sharded_weight = to_unpadded_tensor(sharded_weight)

         if verbose and dist.get_rank() == 0:
             print(f"'{suffix}' weight: {org_weight}, {sharded_weight}")
@@ -73,7 +73,8 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,

     # check weights
     if test_config["precision"] == "fp32":
-        atol, rtol = 5e-4, 1e-3
+        # TODO: the precision in weight checking is too significant
+        atol, rtol = 1e-3, 1e-3
     else:
         atol, rtol = 5e-3, 5e-3
     if stage_manager is None or stage_manager.is_first_stage():
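These tolerances are presumably the ones passed to assert_close when the original and sharded weights are compared; a minimal usage sketch with made-up tensors:

import torch
from torch.testing import assert_close

org_weight = torch.ones(4, 4)
sharded_weight = org_weight + 5e-4  # small drift, within the relaxed fp32 tolerances above

assert_close(org_weight, sharded_weight, atol=1e-3, rtol=1e-3)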
@@ -0,0 +1,46 @@
+import torch
+
+from colossalai.device.device_mesh import DeviceMesh
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from colossalai.tensor.d_tensor import ShardingSpec, distribute_tensor, is_distributed_tensor, to_global
+from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+
+def check_padded_tensor(rank, world_size, port):
+    disable_existing_loggers()
+    launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    original_tensor = torch.rand(32, 64).to("cuda")
+
+    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
+    target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict={0: [0]})
+    d_tensor = distribute_tensor(original_tensor, device_mesh, target_sharding_spec)
+
+    padded_tensor = to_padded_tensor(d_tensor, current_length=64, padding_dim=0)
+    assert padded_tensor.dist_layout == d_tensor.dist_layout
+
+    tensor_copy = padded_tensor.clone()
+    assert is_padded_tensor(tensor_copy)
+    assert is_distributed_tensor(tensor_copy)
+
+    tensor_detached = padded_tensor.detach()
+    assert is_padded_tensor(tensor_detached)
+    assert is_distributed_tensor(tensor_detached)
+
+    unpadded_tensor = to_unpadded_tensor(padded_tensor)
+    assert unpadded_tensor.shape == d_tensor.shape
+    assert is_distributed_tensor(unpadded_tensor)
+
+    global_tensor = to_global(unpadded_tensor)
+    assert global_tensor.shape == original_tensor.shape
+
+
+@rerun_if_address_is_in_use()
+def test_padded_tensor():
+    world_size = 4
+    spawn(check_padded_tensor, world_size)
+
+
+if __name__ == "__main__":
+    test_padded_tensor()
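The padding API exercised by this test also accepts a plain, non-distributed tensor, which is how the optimizer load path earlier in this diff uses it. A small round-trip sketch with made-up sizes:

import torch

from colossalai.tensor.padded_tensor import is_padded_tensor, to_padded_tensor, to_unpadded_tensor

t = torch.rand(32, 64)
padded = to_padded_tensor(t, current_length=48, padding_dim=0)  # pad dim 0 from 32 to 48 rows
assert is_padded_tensor(padded)
assert padded.shape == torch.Size([48, 64])

restored = to_unpadded_tensor(padded)
assert restored.shape == t.shape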