mirror of https://github.com/hpcaitech/ColossalAI
Frank Lee · 1 year ago · 24 changed files with 242 additions and 171 deletions
@@ -1,5 +1,6 @@
from .albert import *
from .bert import *
from .gpt import *
from .llama import *
from .opt import *
from .t5 import *
@@ -0,0 +1,76 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

try:
    from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
    HAS_LLAMA = True
except ImportError:
    HAS_LLAMA = False

if HAS_LLAMA:
    # ===============================
    # Register LLaMA
    # ===============================

    def data_gen():
        # the input ids correspond to the sentence
        # 'Hello, my dog is cute'
        #
        # the code to generate them is given below:
        # -----------------------------------
        # from transformers import LlamaTokenizerFast
        # tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
        # input = 'Hello, my dog is cute'
        # tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
        # -----------------------------------

        input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long()
        attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long()
        return dict(input_ids=input_ids, attention_mask=attention_mask)

    # labels are needed for the causal LM
    def data_gen_for_casual_lm():
        data = data_gen()
        labels = data['input_ids'].clone()
        data['labels'] = labels
        return data

    # transform the output to a dict
    output_transform_fn = lambda x: x

    # functions to compute the loss for each model variant
    loss_fn = lambda output: output.last_hidden_state.mean()
    loss_fn_for_casual_lm = lambda output: output.loss
    loss_fn_for_seq_classification = lambda output: output.logits.mean()

    config = LlamaConfig(num_hidden_layers=4,
                         hidden_size=128,
                         intermediate_size=256,
                         num_attention_heads=4,
                         max_position_embeddings=128,
                         num_labels=16)

    # register the following models
    # transformers.LlamaModel,
    # transformers.LlamaForCausalLM,
    # transformers.LlamaForSequenceClassification,
    model_zoo.register(name='transformers_llama',
                       model_fn=lambda: transformers.LlamaModel(config),
                       data_gen_fn=data_gen,
                       output_transform_fn=output_transform_fn,
                       loss_fn=loss_fn,
                       model_attribute=ModelAttribute(has_control_flow=True))
    model_zoo.register(name='transformers_llama_for_casual_lm',
                       model_fn=lambda: transformers.LlamaForCausalLM(config),
                       data_gen_fn=data_gen_for_casual_lm,
                       output_transform_fn=output_transform_fn,
                       loss_fn=loss_fn_for_casual_lm,
                       model_attribute=ModelAttribute(has_control_flow=True))
    model_zoo.register(name='transformers_llama_for_sequence_classification',
                       model_fn=lambda: transformers.LlamaForSequenceClassification(config),
                       data_gen_fn=data_gen,
                       output_transform_fn=output_transform_fn,
                       loss_fn=loss_fn_for_seq_classification,
                       model_attribute=ModelAttribute(has_control_flow=True))
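
Note (not part of the diff): each registered entry bundles a model constructor, a data generator, an output transform, and a loss function, and downstream tests are expected to wire these together. The sketch below shows that wiring for the causal-LM entry using only the Hugging Face API; it mirrors the registration above and is purely illustrative.

```python
# Illustrative only: exercise the same pieces the 'transformers_llama_for_casual_lm'
# entry registers -- a tiny LlamaForCausalLM, token ids for "Hello, my dog is cute",
# and a loss taken from the model output.
import torch
import transformers
from transformers import LlamaConfig

config = LlamaConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256,
                     num_attention_heads=4, max_position_embeddings=128)
model = transformers.LlamaForCausalLM(config)

input_ids = torch.tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]])
attention_mask = torch.ones_like(input_ids)

# passing labels makes the causal LM head return a scalar loss
output = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids.clone())
output.loss.backward()
```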
@@ -0,0 +1,38 @@
import copy

from colossalai.shardformer import ShardConfig, ShardFormer


def build_model(world_size, model_fn):
    # create a new model
    org_model = model_fn().cuda()

    # shard a copy of the model
    shard_config = ShardConfig(tensor_parallel_size=world_size)
    model_copy = copy.deepcopy(org_model)
    shard_former = ShardFormer(shard_config=shard_config)
    shard_former.init_distributed()
    sharded_model = shard_former.shard_model(model_copy)

    return org_model, sharded_model


def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    # prepare input
    data = data_gen_fn()
    data = {k: v.cuda() for k, v in data.items()}

    # switch to train mode
    original_model.train()
    sharded_model.train()

    # run forward on both the original and the sharded model
    org_output = original_model(**data)
    org_output = output_transform_fn(org_output)
    org_loss = loss_fn(org_output)

    shard_output = sharded_model(**data)
    shard_output = output_transform_fn(shard_output)
    shard_loss = loss_fn(shard_output)

    return org_output, org_loss, shard_output, shard_loss
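
A hedged sketch of how these helpers are typically combined in a test: build the original and sharded models, run a forward pass through both, and check that the losses agree. The function name `check_llama_forward`, the tolerances, and the assumption that the distributed environment and a CUDA device have already been set up by the test launcher are all illustrative, not taken from this PR.

```python
# Hypothetical test body; assumes torch.distributed has been initialised by the
# test launcher, a CUDA device is available, and build_model / run_forward from
# the utility module above are importable.
import torch
import transformers
from transformers import LlamaConfig


def check_llama_forward(world_size):
    config = LlamaConfig(num_hidden_layers=4, hidden_size=128, intermediate_size=256,
                         num_attention_heads=4, max_position_embeddings=128)

    # build the original model and its tensor-parallel shard
    org_model, sharded_model = build_model(world_size, lambda: transformers.LlamaModel(config))

    data_gen_fn = lambda: dict(
        input_ids=torch.tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]),
        attention_mask=torch.ones(1, 8, dtype=torch.long))
    output_transform_fn = lambda x: x
    loss_fn = lambda output: output.last_hidden_state.mean()

    org_output, org_loss, shard_output, shard_loss = run_forward(
        org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

    # the sharded model should reproduce the original forward pass
    torch.testing.assert_close(org_loss, shard_loss, rtol=1e-4, atol=1e-4)
```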