mirror of https://github.com/hpcaitech/ColossalAI
parent
45927d5527
commit
6b30dfb7ce
|
@ -21,7 +21,7 @@ class DistCrossEntropy(Function):
|
||||||
and can be rewrite as:
|
and can be rewrite as:
|
||||||
loss = log(sum(exp(x[i])) - x[class]
|
loss = log(sum(exp(x[i])) - x[class]
|
||||||
|
|
||||||
To avoid the `nan` of log(sim(exp(x[i]))), we minus the max of x[i]
|
To avoid the `nan` of log(sum(exp(x[i]))), we minus the max of x[i]
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
|
vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
|
||||||
|
|
|
@ -19,6 +19,20 @@ def build_policies():
|
||||||
|
|
||||||
from .bert import BertForSequenceClassificationPolicy
|
from .bert import BertForSequenceClassificationPolicy
|
||||||
auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
|
auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaModel
|
||||||
|
|
||||||
|
from .llama import LlamaPolicy
|
||||||
|
auto_policy_dict[LlamaModel] = LlamaPolicy
|
||||||
|
|
||||||
|
from transformers import LlamaForSequenceClassification
|
||||||
|
|
||||||
|
from .llama import LlamaForSequenceClassificationPolicy
|
||||||
|
auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy
|
||||||
|
|
||||||
|
from transformers import LlamaForCausalLM
|
||||||
|
|
||||||
|
from .llama import LlamaForCausalLMPolicy
|
||||||
|
auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy
|
||||||
|
|
||||||
from transformers import GPT2Model
|
from transformers import GPT2Model
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,122 @@
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Callable, Dict, List, Tuple, Type
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel
|
||||||
|
|
||||||
|
import colossalai.shardformer.layer.layers as col_nn
|
||||||
|
|
||||||
|
from .basepolicy import Argument, Col_Layer, Policy, Row_Layer
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaPolicy(Policy):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
|
||||||
|
return {
|
||||||
|
LlamaDecoderLayer:
|
||||||
|
Argument(attr_dict={
|
||||||
|
"self_attn.hidden_size": config.hidden_size // world_size,
|
||||||
|
"self_attn.num_heads": config.num_attention_heads // world_size,
|
||||||
|
},
|
||||||
|
param_funcs=[LlamaPolicy.attn_layer, LlamaPolicy.mlp_layer]),
|
||||||
|
LlamaModel:
|
||||||
|
Argument(attr_dict={}, param_funcs=[LlamaPolicy.embeddings])
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def attn_layer() -> List:
|
||||||
|
return [
|
||||||
|
Col_Layer(
|
||||||
|
suffix="self_attn.q_proj",
|
||||||
|
weight="weight",
|
||||||
|
bias="bias",
|
||||||
|
replace_layer=col_nn.Linear1D_Col,
|
||||||
|
),
|
||||||
|
Col_Layer(
|
||||||
|
suffix="self_attn.k_proj",
|
||||||
|
weight="weight",
|
||||||
|
bias="bias",
|
||||||
|
replace_layer=col_nn.Linear1D_Col,
|
||||||
|
),
|
||||||
|
Col_Layer(
|
||||||
|
suffix="self_attn.v_proj",
|
||||||
|
weight="weight",
|
||||||
|
bias="bias",
|
||||||
|
replace_layer=col_nn.Linear1D_Col,
|
||||||
|
),
|
||||||
|
Row_Layer(
|
||||||
|
suffix="self_attn.o_proj",
|
||||||
|
weight="weight",
|
||||||
|
bias="bias",
|
||||||
|
replace_layer=col_nn.Linear1D_Row,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def mlp_layer() -> List:
|
||||||
|
return [
|
||||||
|
Col_Layer(
|
||||||
|
suffix="mlp.gate_proj",
|
||||||
|
weight="weight",
|
||||||
|
bias="bias",
|
||||||
|
replace_layer=col_nn.Linear1D_Col,
|
||||||
|
gather_output=True,
|
||||||
|
),
|
||||||
|
Col_Layer(
|
||||||
|
suffix="mlp.up_proj",
|
||||||
|
weight="weight",
|
||||||
|
bias="bias",
|
||||||
|
replace_layer=col_nn.Linear1D_Row,
|
||||||
|
gather_output=True,
|
||||||
|
),
|
||||||
|
Col_Layer(
|
||||||
|
suffix="mlp.down_proj",
|
||||||
|
weight="weight",
|
||||||
|
bias="bias",
|
||||||
|
replace_layer=col_nn.Linear1D_Col,
|
||||||
|
gather_output=True,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def embeddings() -> List:
|
||||||
|
return [Col_Layer(
|
||||||
|
suffix="embed_tokens",
|
||||||
|
weight="weight",
|
||||||
|
replace_layer=col_nn.VocabParallelEmbedding1D,
|
||||||
|
)]
|
||||||
|
|
||||||
|
from transformers import LlamaForCausalLM
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaForCausalLMPolicy(LlamaPolicy):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def argument(config, world_size):
|
||||||
|
llamapolicy = LlamaPolicy.argument_policy(config, world_size)
|
||||||
|
argument = {LlamaForCausalLM: Argument(attr_dict={}, param_funcs=[LlamaForCausalLMPolicy.lm_head])}
|
||||||
|
argument.update(llamapolicy)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def lm_head() -> List:
|
||||||
|
return [Col_Layer(suffix="lm_head", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]
|
||||||
|
|
||||||
|
|
||||||
|
from transformers import LlamaForSequenceClassification
|
||||||
|
|
||||||
|
|
||||||
|
class LlamaForSequenceClassificationPolicy(LlamaPolicy):
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def argument(config, world_size):
|
||||||
|
llamapolicy = LlamaPolicy.argument_policy(config, world_size)
|
||||||
|
argument = {
|
||||||
|
LlamaForSequenceClassification:
|
||||||
|
Argument(attr_dict={}, param_funcs=[LlamaForSequenceClassificationPolicy.score])
|
||||||
|
}
|
||||||
|
argument.update(llamapolicy)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def score() -> List:
|
||||||
|
return [Col_Layer(suffix="score", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]
|
|
@ -0,0 +1,106 @@
|
||||||
|
import copy
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, LlamaModel, LlamaTokenizerFast
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.logging import disable_existing_loggers
|
||||||
|
from colossalai.shardformer.shard import ShardConfig, shard_model
|
||||||
|
from colossalai.testing import rerun_if_address_is_in_use, spawn
|
||||||
|
|
||||||
|
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
|
||||||
|
CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=4, mode='1d')),)
|
||||||
|
tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
|
||||||
|
|
||||||
|
|
||||||
|
def build_model(rank, world_size):
|
||||||
|
cfg = LlamaConfig(num_hidden_layers=16)
|
||||||
|
org_model = LlamaForCausalLM(cfg)
|
||||||
|
|
||||||
|
shardconfig = ShardConfig(
|
||||||
|
rank=rank,
|
||||||
|
world_size=world_size,
|
||||||
|
gather_output=True,
|
||||||
|
)
|
||||||
|
org_model = org_model.to('cuda')
|
||||||
|
|
||||||
|
org_model_forshard = copy.deepcopy(org_model)
|
||||||
|
sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')
|
||||||
|
|
||||||
|
return org_model, sharded_model
|
||||||
|
|
||||||
|
|
||||||
|
def check_forward(org_model, sharded_model):
|
||||||
|
input = 'Hello, my dog is cute'
|
||||||
|
inputs = tokenizer(input, return_tensors='pt').to('cuda')
|
||||||
|
del inputs["token_type_ids"]
|
||||||
|
del inputs["attention_mask"]
|
||||||
|
#orgin model
|
||||||
|
org_model.eval()
|
||||||
|
org_out = org_model(**inputs)
|
||||||
|
|
||||||
|
#shard model
|
||||||
|
sharded_model.eval()
|
||||||
|
shard_out = sharded_model(**inputs)
|
||||||
|
|
||||||
|
assert torch.allclose(
|
||||||
|
org_out[0], shard_out[0],
|
||||||
|
atol=1e-4), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}"
|
||||||
|
|
||||||
|
|
||||||
|
def check_backward(org_model, sharded_model):
|
||||||
|
# prepare input
|
||||||
|
input = 'Hello, my dog is cute'
|
||||||
|
tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
|
||||||
|
del tokenized_input["token_type_ids"]
|
||||||
|
del tokenized_input["attention_mask"]
|
||||||
|
labels = tokenized_input['input_ids'].clone()
|
||||||
|
labels[labels == tokenizer.pad_token_id] = -100
|
||||||
|
tokenized_input['labels'] = labels
|
||||||
|
|
||||||
|
#orgin model
|
||||||
|
org_model.train()
|
||||||
|
org_out = org_model(**tokenized_input)
|
||||||
|
org_loss = org_out.loss
|
||||||
|
org_loss.backward()
|
||||||
|
org_grad = org_model.model.layers[0].self_attn.q_proj.weight.grad
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
#shard model
|
||||||
|
sharded_model.train()
|
||||||
|
shard_out = sharded_model(**tokenized_input)
|
||||||
|
shard_loss = shard_out.loss
|
||||||
|
shard_loss.backward()
|
||||||
|
shard_grad = sharded_model.model.layers[0].self_attn.q_proj.weight.grad
|
||||||
|
shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
|
||||||
|
shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
|
||||||
|
all_shard_grad = torch.cat(shard_grad_list, dim=0)
|
||||||
|
|
||||||
|
assert torch.allclose(org_loss, shard_loss,
|
||||||
|
atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
|
||||||
|
assert torch.allclose(org_grad, all_shard_grad,
|
||||||
|
atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"
|
||||||
|
|
||||||
|
|
||||||
|
def check_llama(rank, world_size, port):
|
||||||
|
disable_existing_loggers()
|
||||||
|
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
|
||||||
|
org_model, sharded_model = build_model(rank, world_size)
|
||||||
|
check_forward(org_model, sharded_model)
|
||||||
|
check_backward(org_model, sharded_model)
|
||||||
|
|
||||||
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_llama():
|
||||||
|
spawn(check_llama, 4)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
test_llama()
|
Loading…
Reference in New Issue