from typing import Any, Callable, Dict, List, Tuple, Type

import torch.nn as nn
from transformers import GPT2LMHeadModel
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model

import colossalai.shardformer.layer.layers as col_nn

from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer


class GPT2Policy(Policy):
    """ShardFormer policy for Hugging Face GPT-2.

    Maps each shardable sub-module of ``GPT2Model`` to its tensor-parallel
    replacement layer and rewrites the size-dependent attributes of every
    ``GPT2Block`` for a given tensor-parallel ``world_size``.
    """

    @staticmethod
    def argument_policy(config, world_size):
        return {
            GPT2Model:
                Argument(attr_dict={}, param_funcs=[
                    GPT2Policy.embedding,
                ]),
            GPT2Block:
                Argument(
                    attr_dict={
                        # 1. reduce hidden size
                        "attn.embed_dim": config.hidden_size // world_size,
                        "attn.split_size": config.hidden_size // world_size,
                        "crossattention.embed_dim": config.hidden_size // world_size,
                        "crossattention.split_size": config.hidden_size // world_size,
                        # 2. reduce number of heads
                        "attn.num_heads": config.num_attention_heads // world_size,
                        "crossattention.num_heads": config.num_attention_heads // world_size,
                    },
                    param_funcs=[
                        GPT2Policy.attn_in,
                        GPT2Policy.attn_out,
                        GPT2Policy.mlp_in,
                        GPT2Policy.mlp_out,
                    ]),
        }
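
    # Worked example (illustration only): with the GPT-2 base defaults
    # (hidden_size=768, num_attention_heads=12) and world_size=2, each
    # GPT2Block is rewritten to attn.embed_dim == attn.split_size == 384
    # and attn.num_heads == 6, i.e. half the projection width and half of
    # the attention heads per tensor-parallel rank.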

    @staticmethod
    def attn_in() -> List:
        # GPT-2 implements its projections as transposed Conv1D layers, hence
        # reversed=True. n_cast marks fused projections (attn.c_attn packs
        # Q/K/V, the cross-attention c_attn packs K/V) so each fused block is
        # sharded separately, and ignore=True tolerates the module being
        # absent (cross attention exists only when the config enables it).
        return [
            Col_Layer(suffix="attn.c_attn",
                      weight="weight",
                      bias="bias",
                      n_cast=3,
                      reversed=True,
                      replace_layer=col_nn.Linear1D_Col),
            Col_Layer(suffix="crossattention.c_attn",
                      weight="weight",
                      bias="bias",
                      n_cast=2,
                      reversed=True,
                      ignore=True,
                      replace_layer=col_nn.Linear1D_Col),
            Col_Layer(suffix="crossattention.q_attn",
                      weight="weight",
                      bias="bias",
                      reversed=True,
                      ignore=True,
                      replace_layer=col_nn.Linear1D_Col)
        ]
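
    # Shape sketch (illustrative, based on the Hugging Face GPT-2 layout):
    # attn.c_attn stores a Conv1D weight of shape (hidden, 3 * hidden). The
    # n_cast=3 hint above lets the shard helper split that fused output into
    # its Q, K and V blocks before column-sharding, so every rank ends up
    # with aligned slices of Q, K and V rather than a contiguous stripe that
    # mixes the three.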

    @staticmethod
    def attn_out() -> List:
        # The attention output projections are row-parallel: each rank
        # multiplies its partial hidden slice and the partial results are
        # summed back to the full hidden size.
        return [
            Row_Layer(suffix="attn.c_proj",
                      weight="weight",
                      bias="bias",
                      reversed=True,
                      replace_layer=col_nn.Linear1D_Row),
            Row_Layer(suffix="crossattention.c_proj",
                      weight="weight",
                      bias="bias",
                      reversed=True,
                      ignore=True,
                      replace_layer=col_nn.Linear1D_Row)
        ]

    @staticmethod
    def mlp_in() -> List:
        return [
            Col_Layer(suffix="mlp.c_fc", weight="weight", bias="bias", reversed=True,
                      replace_layer=col_nn.Linear1D_Col),
        ]

    @staticmethod
    def mlp_out() -> List:
        return [
            Row_Layer(suffix="mlp.c_proj",
                      weight="weight",
                      bias="bias",
                      reversed=True,
                      replace_layer=col_nn.Linear1D_Row)
        ]
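
    # Design note: the Col->Row pairing above is the classic Megatron-LM
    # tensor-parallel pattern. Column-parallel c_fc produces an already
    # sharded intermediate activation that row-parallel c_proj consumes
    # directly, so the only forward-pass communication per MLP (and likewise
    # per attention block) is the single all-reduce inside the Row layer.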

    @staticmethod
    def embedding() -> List:
        return [Col_Layer(suffix="wte", weight="weight", replace_layer=col_nn.VocabParallelEmbedding1D)]
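
    # VocabParallelEmbedding1D shards the wte table along the vocabulary
    # dimension in the usual Megatron fashion: each rank embeds only the
    # token ids that fall in its vocab shard and the partial embeddings are
    # all-reduced, so no single rank has to hold the full table.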


class GPT2LMHeadModelPolicy(GPT2Policy):
    """Policy for ``GPT2LMHeadModel``: reuses the GPT-2 body policy and adds
    a column-parallel shard for the ``lm_head`` output projection."""

    @staticmethod
    def argument_policy(config, world_size):
        base_argument = GPT2Policy.argument_policy(config, world_size)
        argument = {
            GPT2LMHeadModel: Argument(attr_dict={}, param_funcs=[
                GPT2LMHeadModelPolicy.unembedding,
            ]),
        }
        argument.update(base_argument)
        return argument
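
    # Illustration: the merged mapping covers three module classes,
    #   {GPT2LMHeadModel: ..., GPT2Model: ..., GPT2Block: ...}
    # so the shard engine can match both the wrapper model and the inherited
    # body policies in a single pass.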

    @staticmethod
    def unembedding() -> List:
        # lm_head is a plain nn.Linear (no Conv1D transpose, hence no
        # reversed flag); gather_output=True all-gathers the sharded logits
        # so callers see the full vocabulary dimension.
        return [
            Col_Layer(suffix="lm_head",
                      weight="weight",
                      bias="bias",
                      replace_layer=col_nn.Linear1D_Col,
                      gather_output=True)
        ]
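

# Usage sketch (hypothetical driver code; the real consumer is the
# ShardFormer engine and its exact API may differ):
#
#     from transformers import GPT2Config
#
#     policy = GPT2LMHeadModelPolicy.argument_policy(GPT2Config(), world_size=2)
#     # `policy` maps module classes to Argument(attr_dict, param_funcs).
#     # A shard engine would walk the model, apply each attr_dict override,
#     # call the param_funcs to learn which sub-layers to swap, and replace
#     # them with their replace_layer counterparts.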