# Source: ColossalAI/colossalai/shardformer/policies/gpt2.py (Python, 127 lines, 4.1 KiB)

from typing import Any, Callable, Dict, List, Tuple, Type
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model
import colossalai.shardformer.layer.layers as col_nn
from .basepolicy import Argument, Col_Layer, Layer, Policy, Row_Layer
class GPT2Policy(Policy):
    """Sharding policy for the Hugging Face GPT-2 backbone.

    ``argument_policy`` maps module classes to ``Argument`` records; the other
    static methods are the layer-descriptor factories those records reference
    through ``param_funcs``.
    """

    @staticmethod
    def argument_policy(config, world_size):
        """Build the per-module-class sharding arguments.

        Args:
            config: Model configuration; ``hidden_size`` and
                ``num_attention_heads`` are read from it.
            world_size: Divisor applied to the hidden size and the head count.
                NOTE(review): presumably the tensor-parallel group size -- confirm.

        Returns:
            A dict with two entries, keyed by ``GPT2Model`` and ``GPT2Block``,
            each holding an ``Argument``.
        """
        model_argument = Argument(attr_dict={}, param_funcs=[GPT2Policy.embedding])

        # Floor division: a value that is not a multiple of ``world_size`` is
        # truncated, not rejected.
        hidden_per_rank = config.hidden_size // world_size
        heads_per_rank = config.num_attention_heads // world_size

        # 1. attributes that receive the reduced hidden size
        hidden_size_attrs = (
            "attn.embed_dim",
            "attn.split_size",
            "crossattention.embed_dim",
            "crossattention.split_size",
        )
        # 2. attributes that receive the reduced number of heads
        head_count_attrs = (
            "attn.num_heads",
            "crossattention.num_heads",
        )
        block_attrs = {
            **dict.fromkeys(hidden_size_attrs, hidden_per_rank),
            **dict.fromkeys(head_count_attrs, heads_per_rank),
        }

        block_argument = Argument(
            attr_dict=block_attrs,
            param_funcs=[
                GPT2Policy.attn_in,
                GPT2Policy.attn_out,
                GPT2Policy.mlp_in,
                GPT2Policy.mlp_out,
            ],
        )

        return {GPT2Model: model_argument, GPT2Block: block_argument}

    @staticmethod
    def attn_in() -> List:
        """Return the ``Col_Layer`` descriptors for the attention input projections.

        Covers ``attn.c_attn`` plus the two cross-attention projections; the
        cross-attention entries carry ``ignore=True``.
        NOTE(review): ``n_cast`` (3 and 2 here), ``reversed`` and ``ignore`` are
        interpreted by the sharder, which is not visible in this file -- confirm
        their exact meaning in ``basepolicy``.
        """
        self_attn_qkv = Col_Layer(
            suffix="attn.c_attn",
            weight="weight",
            bias="bias",
            n_cast=3,
            reversed=True,
            replace_layer=col_nn.Linear1D_Col,
        )
        cross_attn_kv = Col_Layer(
            suffix="crossattention.c_attn",
            weight="weight",
            bias="bias",
            n_cast=2,
            reversed=True,
            ignore=True,
            replace_layer=col_nn.Linear1D_Col,
        )
        cross_attn_q = Col_Layer(
            suffix="crossattention.q_attn",
            weight="weight",
            bias="bias",
            reversed=True,
            ignore=True,
            replace_layer=col_nn.Linear1D_Col,
        )
        return [self_attn_qkv, cross_attn_kv, cross_attn_q]

    @staticmethod
    def attn_out() -> List:
        """Return the ``Row_Layer`` descriptors for the attention output projections.

        Covers ``attn.c_proj`` and ``crossattention.c_proj``; only the
        cross-attention entry carries ``ignore=True``.
        """
        self_attn_proj = Row_Layer(
            suffix="attn.c_proj",
            weight="weight",
            bias="bias",
            reversed=True,
            replace_layer=col_nn.Linear1D_Row,
        )
        cross_attn_proj = Row_Layer(
            suffix="crossattention.c_proj",
            weight="weight",
            bias="bias",
            reversed=True,
            ignore=True,
            replace_layer=col_nn.Linear1D_Row,
        )
        return [self_attn_proj, cross_attn_proj]

    @staticmethod
    def mlp_in() -> List:
        """Return the single ``Col_Layer`` descriptor for ``mlp.c_fc``."""
        fc_layer = Col_Layer(
            suffix="mlp.c_fc",
            weight="weight",
            bias="bias",
            reversed=True,
            replace_layer=col_nn.Linear1D_Col,
        )
        return [fc_layer]

    @staticmethod
    def mlp_out() -> List:
        """Return the single ``Row_Layer`` descriptor for ``mlp.c_proj``."""
        proj_layer = Row_Layer(
            suffix="mlp.c_proj",
            weight="weight",
            bias="bias",
            reversed=True,
            replace_layer=col_nn.Linear1D_Row,
        )
        return [proj_layer]

    @staticmethod
    def embedding() -> List:
        """Return the descriptor replacing ``wte`` with ``VocabParallelEmbedding1D``.

        No ``bias`` is named for this layer, only ``weight``.
        """
        token_embedding = Col_Layer(
            suffix="wte",
            weight="weight",
            replace_layer=col_nn.VocabParallelEmbedding1D,
        )
        return [token_embedding]
from transformers import GPT2LMHeadModel
class GPT2LMHeadModelPolicy(GPT2Policy):
    """Sharding policy for ``GPT2LMHeadModel``.

    Reuses every entry produced by ``GPT2Policy.argument_policy`` and adds one
    entry for the language-model head.
    """

    @staticmethod
    def argument_policy(config, world_size):
        """Build the sharding arguments for the LM-head model.

        Args:
            config: Forwarded unchanged to ``GPT2Policy.argument_policy``.
            world_size: Forwarded unchanged to ``GPT2Policy.argument_policy``.

        Returns:
            A new dict whose first key is ``GPT2LMHeadModel``, followed by the
            base policy's entries. A base entry under the same key would
            replace the LM-head value, as with ``dict.update``.
        """
        # The base policy is evaluated first, as in the original statement order.
        backbone_arguments = GPT2Policy.argument_policy(config, world_size)
        lm_head_argument = Argument(
            attr_dict={},
            param_funcs=[GPT2LMHeadModelPolicy.unembedding],
        )
        return {GPT2LMHeadModel: lm_head_argument, **backbone_arguments}

    @staticmethod
    def unembedding() -> List:
        """Return the single ``Col_Layer`` descriptor for ``lm_head``.

        The descriptor sets ``gather_output=True``.
        NOTE(review): presumably this makes the replaced layer gather the
        column shards back into full-size output -- confirm in ``Linear1D_Col``.
        """
        lm_head_layer = Col_Layer(
            suffix="lm_head",
            weight="weight",
            bias="bias",
            replace_layer=col_nn.Linear1D_Col,
            gather_output=True,
        )
        return [lm_head_layer]