# coding=utf-8
# Copyright 2022 Google LLC and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Convert T5X checkpoint to PyTorch
Steps:
- Install gsutil according to https://cloud.google.com/storage/docs/gsutil_install
- Get a T5X checkpoint at https://github.com/google-research/t5x/blob/main/docs/models.md#t5-11-checkpoints Example:
`gsutil -m cp -r gs://t5-data/pretrained_models/t5x/t5_1_1_small $HOME/`
- Create or download a corresponding config for the downloaded model. E.g. for T5 v1.1 small, you can use
https://huggingface.co/google/t5-v1_1-small/blob/main/config.json
- Convert:
```
python3 convert_t5x_checkpoint_to_pytorch.py --t5x_checkpoint_path=$HOME/t5_1_1_small --config_file=config.json\
--pytorch_dump_path=$HOME/t5_1_1_small_pt
```
"""
import argparse
import collections

import torch
from flax import traverse_util
from modeling_openmoe import OpenMoeForCausalLM
from t5x import checkpoints
from transformers import LlamaConfig
from transformers.utils import logging

logging.set_verbosity_info()


def t5x_attention_lookup(params, i, prefix, layer_name="attention"):
    """Returns the KOQV parameters of (self-)attention. Does not transpose."""
    k = params[f"{prefix}/layers_{i}/{layer_name}/key/kernel"]
    o = params[f"{prefix}/layers_{i}/{layer_name}/out/kernel"]
    q = params[f"{prefix}/layers_{i}/{layer_name}/query/kernel"]
    v = params[f"{prefix}/layers_{i}/{layer_name}/value/kernel"]
    return k, o, q, v
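# For example, with prefix="decoder" and layer_name="self_attention" (as this
# script calls it below), block 0's query kernel lives under the flattened key
# "decoder/layers_0/self_attention/query/kernel".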


def t5x_mlp_lookup(params, i, prefix, split_mlp_wi=False):
    """Returns the MLP parameters of a layer. Does not transpose."""
    if split_mlp_wi:
        wi_0 = params[f"{prefix}/layers_{i}/mlp/wi_0/kernel"]
        wi_1 = params[f"{prefix}/layers_{i}/mlp/wi_1/kernel"]
        wi = (wi_0, wi_1)
    else:
        wi = params[f"{prefix}/layers_{i}/mlp/wi/kernel"]
    wo = params[f"{prefix}/layers_{i}/mlp/wo/kernel"]
    return wi, wo
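# When split_mlp_wi is True (gated activations, as in T5 v1.1), wi_0 acts as
# the gate projection and wi_1 as the up projection; convert_t5x_to_pytorch
# below maps them to gate_proj/up_proj accordingly.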


def t5x_extra_mlp_lookup(params, i, prefix, split_mlp_wi=False):
    """Returns the extra MLP parameters of a MoE layer. Does not transpose."""
    if split_mlp_wi:
        wi_0 = params[f"{prefix}/layers_{i}/extra_mlp/wi_0/kernel"]
        wi_1 = params[f"{prefix}/layers_{i}/extra_mlp/wi_1/kernel"]
        wi = (wi_0, wi_1)
    else:
        wi = params[f"{prefix}/layers_{i}/extra_mlp/wi/kernel"]
    wo = params[f"{prefix}/layers_{i}/extra_mlp/wo/kernel"]
    return wi, wo


def t5x_experts_lookup(params, i, prefix, split_mlp_wi=False):
    """Returns the expert MLP parameters of a MoE layer. Does not transpose."""
    if split_mlp_wi:
        wi_0 = params[f"{prefix}/layers_{i}/mlp/expert/wi_0/kernel"]
        wi_1 = params[f"{prefix}/layers_{i}/mlp/expert/wi_1/kernel"]
        wi = (wi_0, wi_1)
    else:
        wi = params[f"{prefix}/layers_{i}/mlp/expert/wi/kernel"]
    wo = params[f"{prefix}/layers_{i}/mlp/expert/wo/kernel"]
    return wi, wo


def t5x_gate_lookup(params, i, prefix, split_mlp_wi=False):
    """Returns the router gate weights of a MoE layer. Does not transpose."""
    # split_mlp_wi is accepted for signature parity with the other lookups but is unused.
    return params[f"{prefix}/layers_{i}/mlp/router/router_weights/w/kernel"]
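# The router kernel presumably has shape (hidden_dim, num_experts) in the T5X
# checkpoint (Flax dense layout); convert_t5x_to_pytorch transposes it before
# storing it as `mlp.gate_weight`.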


def t5x_layer_norm_lookup(params, i, prefix, layer_name):
    """Returns the layer norm param of a layer."""
    return params[f"{prefix}/layers_{i}/{layer_name}/scale"]


def convert_t5x_to_pytorch(variables: dict, *, num_layers: int, moe_interval: int):
    """Converts the parameters from T5X-Flax to Transformers-PyTorch."""
    old = traverse_util.flatten_dict(variables["target"])
    old = {"/".join(k): v for k, v in old.items()}

    # v1.1 models have a gated GeLU with wi_0 and wi_1 instead of wi
    split_mlp_wi = True
    print("Split MLP:", split_mlp_wi)

    new = collections.OrderedDict()
    print(old.keys())
    for key, value in old.items():
        print(f"{key}: {value.shape}")

    # Shared embeddings.
    new["model.embed_tokens.weight"] = old["token_embedder/embedding"]

    # Decoder.
    for i in range(num_layers):
        # Block i, layer 0 (Self Attention).
        layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_self_attention_layer_norm")
        k, o, q, v = t5x_attention_lookup(old, i, "decoder", "self_attention")
        new[f"model.layers.{i}.input_layernorm.weight"] = layer_norm
        new[f"model.layers.{i}.self_attn.k_proj.weight"] = k.T
        new[f"model.layers.{i}.self_attn.o_proj.weight"] = o.T
        new[f"model.layers.{i}.self_attn.q_proj.weight"] = q.T
        new[f"model.layers.{i}.self_attn.v_proj.weight"] = v.T

        # Block i, layer 2 (MLP).
        layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_mlp_layer_norm")
        new[f"model.layers.{i}.post_attention_layernorm.weight"] = layer_norm
        if (i + 1) % moe_interval == 0:
            # MoE block: router gate plus expert MLPs.
            gate = t5x_gate_lookup(old, i, "decoder", split_mlp_wi)
            new[f"model.layers.{i}.mlp.gate_weight"] = gate.T
            wi, wo = t5x_experts_lookup(old, i, "decoder", split_mlp_wi)
            new[f"model.layers.{i}.mlp.experts.wi_gate"] = wi[0]
            new[f"model.layers.{i}.mlp.experts.wi_up"] = wi[1]
            new[f"model.layers.{i}.mlp.experts.wo"] = wo
            # Extra dense MLP present in MoE blocks.
            layer_norm = t5x_layer_norm_lookup(old, i, "decoder", "pre_extra_mlp_layer_norm")
            new[f"model.layers.{i}.pre_extra_mlp_layernorm.weight"] = layer_norm
            wi, wo = t5x_extra_mlp_lookup(old, i, "decoder", split_mlp_wi)
            new[f"model.layers.{i}.extra_mlp.gate_proj.weight"] = wi[0].T
            new[f"model.layers.{i}.extra_mlp.up_proj.weight"] = wi[1].T
            new[f"model.layers.{i}.extra_mlp.down_proj.weight"] = wo.T
        else:
            # Dense (non-MoE) block: plain gated MLP.
            wi, wo = t5x_mlp_lookup(old, i, "decoder", split_mlp_wi)
            new[f"model.layers.{i}.mlp.gate_proj.weight"] = wi[0].T
            new[f"model.layers.{i}.mlp.up_proj.weight"] = wi[1].T
            new[f"model.layers.{i}.mlp.down_proj.weight"] = wo.T

    new["model.norm.weight"] = old["decoder/decoder_norm/scale"]

    # LM Head (only in v1.1 checkpoints, in v1.0 embeddings are used instead).
    if "decoder/logits_dense/kernel" in old:
        new["lm_head.weight"] = old["decoder/logits_dense/kernel"].T

    return new
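
# Summary of the mapping above (T5X flattened key -> PyTorch state-dict key):
#   token_embedder/embedding                    -> model.embed_tokens.weight
#   decoder/layers_{i}/self_attention/*/kernel  -> model.layers.{i}.self_attn.*_proj.weight (transposed)
#   decoder/layers_{i}/mlp/*                    -> model.layers.{i}.mlp.* (dense or expert weights)
#   decoder/decoder_norm/scale                  -> model.norm.weight
#   decoder/logits_dense/kernel                 -> lm_head.weight (transposed)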


def make_state_dict(converted_params):
    """Prepares a state dict for the PyTorch model."""
    # Make a state dict with torch tensors.
    state_dict = collections.OrderedDict([(k, torch.from_numpy(v.copy())) for (k, v) in converted_params.items()])
    return state_dict
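# Note: `.copy()` above materializes each array into fresh writable memory;
# arrays loaded from a T5X checkpoint may be read-only (e.g. memory-mapped),
# which torch.from_numpy cannot wrap cleanly.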


def load_t5x_weights_in_t5(model, config, t5x_checkpoint_path):
    """Replaces the params in model with the T5X converted params."""
    variables = checkpoints.load_t5x_checkpoint(t5x_checkpoint_path)
    converted = convert_t5x_to_pytorch(variables,
                                       num_layers=config.num_hidden_layers,
                                       moe_interval=config.moe_layer_interval)
    state_dict = make_state_dict(converted)
    model.load_state_dict(state_dict, strict=True)
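# Note: `strict=True` above makes loading fail loudly if the converted keys do
# not exactly match OpenMoeForCausalLM's state dict, catching naming drift early.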


def convert_t5x_checkpoint_to_pytorch(t5x_checkpoint_path, config_file, pytorch_dump_path):
    """Loads the config and model, converts the T5X checkpoint, and saves a PyTorch checkpoint."""
    # Initialise the PyTorch model.
    config = LlamaConfig.from_json_file(config_file)
    print(f"Building PyTorch model from configuration: {config}")
    model = OpenMoeForCausalLM(config)

    # Load weights from the T5X checkpoint.
    load_t5x_weights_in_t5(model, config, t5x_checkpoint_path)

    # Save the PyTorch model.
    print(f"Save PyTorch model to {pytorch_dump_path}")
    model.save_pretrained(pytorch_dump_path)

    # Verify that we can reload the checkpoint we just saved.
    OpenMoeForCausalLM.from_pretrained(pytorch_dump_path)
    print("Done")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Converts a native T5X checkpoint into a PyTorch checkpoint.")
    # Required parameters
    parser.add_argument("--t5x_checkpoint_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the T5X checkpoint.")
    parser.add_argument(
        "--config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained OpenMoE model.\nThis specifies the model architecture.",
    )
    parser.add_argument("--pytorch_dump_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the output PyTorch model.")
    args = parser.parse_args()
    convert_t5x_checkpoint_to_pytorch(args.t5x_checkpoint_path, args.config_file, args.pytorch_dump_path)