mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] adapted T5 and LLaMa test to use kit (#4049)
* [shardformer] adapted T5 and LLaMa test to use kit
* polish code

Branch: pull/4157/head
Parent: 4021b9a8a2
Commit: 58df720570
@@ -65,13 +65,14 @@ class Embedding1D(ParallelModule):
                  dtype: torch.dtype = None,
                  device: torch.device = None,
                  process_group: ProcessGroup = None,
+                 gather_output: bool = True,
                  weight_initializer: Callable = init.normal_(),
                  *args,
                  **kwargs):
         super().__init__()

         self.num_embeddings = num_embeddings
-        self.embed_dim = embedding_dim
+        self.embedding_dim = embedding_dim
         self.process_group = process_group
         self.num_partitions = dist.get_world_size(process_group)
         self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
@@ -79,7 +80,7 @@ class Embedding1D(ParallelModule):
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
-        # self.gather_output = gather_output
+        self.gather_output = gather_output

         if device is None:
             device = get_current_device()
@@ -95,7 +96,9 @@ class Embedding1D(ParallelModule):

     @staticmethod
     def from_native_module(module: nn.Embedding,
-                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D":
+                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None,
+                           *args,
+                           **kwargs) -> "Embedding1D":
         r"""
         Build a 1D parallelized Embedding from a native nn.Embedding module.
         """
@@ -123,7 +126,9 @@ class Embedding1D(ParallelModule):
                                 max_norm=max_norm,
                                 norm_type=norm_type,
                                 scale_grad_by_freq=scale_grad_by_freq,
-                                sparse=sparse)
+                                sparse=sparse,
+                                *args,
+                                **kwargs)

         # copy the weight
         with torch.no_grad():
@@ -133,7 +138,7 @@ class Embedding1D(ParallelModule):
         return embedding

     def reset_parameters(self, weight_initializer) -> None:
-        fan_in, fan_out = self.num_embeddings, self.embed_dim
+        fan_in, fan_out = self.num_embeddings, self.embedding_dim
         weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
         self._fill_padding_idx_with_zero()

@@ -144,6 +149,9 @@ class Embedding1D(ParallelModule):

     def forward(self, input_: Tensor) -> Tensor:
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
-        output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
-
-        return output
+
+        if self.gather_output:
+            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+            return output
+        else:
+            return output_parallel
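The hunks above give `Embedding1D` a `gather_output` switch: when it is `False`, `forward` now returns the dimension-sharded embedding instead of all-gathering it across the tensor-parallel group. A minimal usage sketch follows; it is illustrative only (not part of the commit) and assumes `colossalai.launch` has already set up a tensor-parallel group of size 2, with vocabulary and embedding sizes chosen for the example.

```python
# Illustrative sketch of the new gather_output flag; the sizes below are assumptions.
import torch
from colossalai.shardformer.layer import Embedding1D

tokens = torch.randint(0, 1000, (2, 16), device='cuda')

# gather_output=True (default): shards are all-gathered, so the last dim is the full 128
emb_gathered = Embedding1D(num_embeddings=1000, embedding_dim=128, gather_output=True)
full = emb_gathered(tokens)

# gather_output=False: each rank keeps only its 128 // world_size slice of the output
emb_sharded = Embedding1D(num_embeddings=1000, embedding_dim=128, gather_output=False)
partial = emb_sharded(tokens)
```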
@@ -4,7 +4,7 @@ import torch.nn as nn
 from transformers import LlamaForCausalLM, LlamaForSequenceClassification
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

-from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
+from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -11,8 +11,7 @@ from transformers.models.t5.modeling_t5 import (
     T5Stack,
 )

-from colossalai.shardformer.layer.dropout import Dropout1D
-from colossalai.shardformer.layer.layers import Embedding1D, Linear1D_Col, Linear1D_Row
+from colossalai.shardformer.layer import Dropout1D, Embedding1D, Linear1D_Col, Linear1D_Row

 from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
@@ -185,7 +185,14 @@ class ModelSharder(object):
             if description.ignore_if_not_exist and native_sub_module is None:
                 continue

-            replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
-                                                             **kwargs)
+            try:
+                replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
+                                                                  **kwargs)
+            except Exception as e:
+                raise RuntimeError(
+                    f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}"
+                    f" with {target_module.__qualname__} with the exception: {e}. "
+                    "Please check your model configuration or sharding policy, you can set up an issue for us to help you as well."
+                )

             setattr_(org_layer, suffix, replace_layer)
@@ -98,6 +98,6 @@ def assert_hf_output_close(out1: Any,
             raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}")
         assert torch.allclose(
             out1, out2, atol=atol, rtol=rtol
-        ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, mean error: {torch.abs(out1 - out2).mean()}"
+        ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, \nmean error: {torch.abs(out1 - out2).mean()}"
     else:
         assert out1 == out2, f"{track_name}: value mismatch.\nout1: {out1}\nout2: {out2}"
@@ -28,27 +28,35 @@ class ModelZooRegistry(dict):
                  model_fn: Callable,
                  data_gen_fn: Callable,
                  output_transform_fn: Callable,
+                 loss_fn: Callable = None,
                  model_attribute: ModelAttribute = None):
         """
         Register a model and data generation function.

         Examples:
-        >>> # Register
-        >>> model_zoo = ModelZooRegistry()
-        >>> model_zoo.register('resnet18', resnet18, resnet18_data_gen)
-        >>> # Run the model
-        >>> data = resnet18_data_gen() # do not input any argument
-        >>> model = resnet18() # do not input any argument
-        >>> out = model(**data)
+
+        ```python
+        # normal forward workflow
+        model = resnet18()
+        data = resnet18_data_gen()
+        output = model(**data)
+        transformed_output = output_transform_fn(output)
+        loss = loss_fn(transformed_output)
+
+        # Register
+        model_zoo = ModelZooRegistry()
+        model_zoo.register('resnet18', resnet18, resnet18_data_gen, output_transform_fn, loss_fn)
+        ```

         Args:
             name (str): Name of the model.
-            model_fn (callable): A function that returns a model. **It must not contain any arguments.**
-            output_transform_fn (callable): A function that transforms the output of the model into Dict.
-            data_gen_fn (callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
+            model_fn (Callable): A function that returns a model. **It must not contain any arguments.**
+            data_gen_fn (Callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
+            output_transform_fn (Callable): A function that transforms the output of the model into Dict.
+            loss_fn (Callable): a function to compute the loss from the given output. Defaults to None
             model_attribute (ModelAttribute): Attributes of the model. Defaults to None.
         """
-        self[name] = (model_fn, data_gen_fn, output_transform_fn, model_attribute)
+        self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)

     def get_sub_registry(self, keyword: str):
         """
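With `loss_fn` stored in the tuple, every registry entry is now a 5-tuple, which is why the call sites later in this commit unpack an extra `_`. A minimal consumption sketch (illustrative only; `transformers_t5` is one of the sub-registry keywords registered further down, and the contract comments follow the docstring above):

```python
# Sketch of iterating the new 5-tuple entries; assumes the T5/LLaMA registrations below.
from tests.kit.model_zoo import model_zoo

sub_registry = model_zoo.get_sub_registry('transformers_t5')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, attribute) in sub_registry.items():
    model = model_fn()      # takes no arguments, by contract
    data = data_gen_fn()    # returns a dict of tensors, by contract
    output = output_transform_fn(model(**data))
    if loss_fn is not None:    # entries registered without a loss_fn store None
        loss = loss_fn(output)
```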
@@ -1,5 +1,6 @@
 from .albert import *
 from .bert import *
 from .gpt import *
+from .llama import *
 from .opt import *
 from .t5 import *
@@ -0,0 +1,76 @@
+import torch
+import transformers
+
+from ..registry import ModelAttribute, model_zoo
+
+try:
+    from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
+    HAS_LLAMA = True
+except ImportError:
+    HAS_LLAMA = False
+
+if HAS_LLAMA:
+    # ===============================
+    # Register LLaMA
+    # ===============================
+
+    def data_gen():
+        # the input ids are corresponding to the sentence
+        # 'Hello, my dog is cute'
+        #
+        # the code is give below:
+        # -----------------------------------
+        # from transformers import LlamaTokenizerFast
+        # tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
+        # input = 'Hello, my dog is cute'
+        # tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
+        # -----------------------------------
+
+        input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long()
+        attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long()
+        return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+    # label is needed for casual lm
+    def data_gen_for_casual_lm():
+        data = data_gen()
+        labels = data['input_ids'].clone()
+        data['labels'] = labels
+        return data
+
+    # transform the output to a dict
+    output_transform_fn = lambda x: x
+
+    # function to get the loss
+    loss_fn = lambda output: output.last_hidden_state.mean()
+    loss_fn_for_casual_lm = lambda output: output.loss
+    loss_fn_for_seq_classification = lambda output: output.logits.mean()
+
+    config = LlamaConfig(num_hidden_layers=4,
+                         hidden_size=128,
+                         intermediate_size=256,
+                         num_attention_heads=4,
+                         max_position_embeddings=128,
+                         num_labels=16)
+
+    # register the following models
+    # transformers.LlamaModel,
+    # transformers.LlamaForCausalLM,
+    # transformers.LlamaForSequenceClassification,
+    model_zoo.register(name='transformers_llama',
+                       model_fn=lambda: transformers.LlamaModel(config),
+                       data_gen_fn=data_gen,
+                       output_transform_fn=output_transform_fn,
+                       loss_fn=loss_fn,
+                       model_attribute=ModelAttribute(has_control_flow=True))
+    model_zoo.register(name='transformers_llama_for_casual_lm',
+                       model_fn=lambda: transformers.LlamaForCausalLM(config),
+                       data_gen_fn=data_gen_for_casual_lm,
+                       output_transform_fn=output_transform_fn,
+                       loss_fn=loss_fn_for_casual_lm,
+                       model_attribute=ModelAttribute(has_control_flow=True))
+    model_zoo.register(name='transformers_llama_for_sequence_classification',
+                       model_fn=lambda: transformers.LlamaForSequenceClassification(config),
+                       data_gen_fn=data_gen,
+                       output_transform_fn=output_transform_fn,
+                       loss_fn=loss_fn_for_seq_classification,
+                       model_attribute=ModelAttribute(has_control_flow=True))
@@ -6,24 +6,50 @@ from ..registry import ModelAttribute, model_zoo
 # ===============================
 # Register single-sentence T5
 # ===============================
-BATCH_SIZE = 2
-SEQ_LENGTH = 16
-
-
-def data_gen():
-    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-    decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-    return dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)


 # define data gen function
 def data_gen_for_encoder_only():
-    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    # Generated from following code snippet
+    #
+    # from transformers import T5Config, T5Tokenizer
+    # config = T5Config(decoder_start_token_id=0)
+    # tokenizer = T5Tokenizer.from_pretrained("t5-small")
+    # input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
+    input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1]]).long()
     return dict(input_ids=input_ids)


+def data_gen_for_conditional_generation():
+    # labels is generated with the following code
+    #
+    # labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
+    data = data_gen_for_encoder_only()
+    labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1]]).long()
+    data['labels'] = labels
+    return data
+
+
+def data_gen_for_t5_model():
+    # decoder_inputs_ids is obtained with the following code
+    #
+    # decoder_input_ids = model._shift_right(input_ids)
+    data = data_gen_for_encoder_only()
+    decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5]]).long()
+    data['decoder_input_ids'] = decoder_input_ids
+    return data
+
+
 # output transform function
 output_transform_fn = lambda x: x

-config = transformers.T5Config(d_model=128, num_layers=2)
+# define loss funciton
+loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean()
+loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean()
+loss_fn_for_conditional_generation = lambda x: x.loss
+
+# define model config
+config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)

 # register the following models
 # transformers.T5Model,
@@ -31,16 +57,19 @@ config = transformers.T5Config(d_model=128, num_layers=2)
 # transformers.T5EncoderModel,
 model_zoo.register(name='transformers_t5',
                    model_fn=lambda: transformers.T5Model(config),
-                   data_gen_fn=data_gen,
+                   data_gen_fn=data_gen_for_t5_model,
                    output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_t5_model,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_t5_for_conditional_generation',
                    model_fn=lambda: transformers.T5ForConditionalGeneration(config),
-                   data_gen_fn=data_gen,
+                   data_gen_fn=data_gen_for_conditional_generation,
                    output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_conditional_generation,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_t5_encoder_model',
                    model_fn=lambda: transformers.T5EncoderModel(config),
                    data_gen_fn=data_gen_for_encoder_only,
                    output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_encoder_only,
                    model_attribute=ModelAttribute(has_control_flow=True))
@@ -11,7 +11,7 @@ def run_torch_amp(rank, world_size, port):
     # init dist env
     colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
     sub_model_zoo = model_zoo.get_sub_registry('timm')
-    for name, (model_fn, data_gen_fn, output_transform_fn, _) in sub_model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in sub_model_zoo.items():
         # dlrm_interactionarch has not parameters, so skip
         if name == 'dlrm_interactionarch':
             continue
@@ -71,7 +71,7 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
     passed_models = []
     failed_info = {}    # (model_name, error) pair

-    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
         # These models lead to CUDA error
         if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
                     'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext'):
@@ -61,7 +61,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
     ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
     skipped_models = []

-    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
         # FIXME(ver217): fix these models
         if name in ignore_models:
             skipped_models.append(name)
@@ -40,7 +40,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):


 def check_torch_ddp_plugin():
-    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
         if name == 'dlrm_interactionarch':
             continue
         run_fn(model_fn, data_gen_fn, output_transform_fn)
@@ -42,7 +42,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):


 def check_torch_fsdp_plugin():
-    for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
         if any(element in name for element in [
                 'diffusers', 'deepfm_sparsearch', 'dlrm_interactionarch', 'torchvision_googlenet',
                 'torchvision_inception_v3'
@@ -47,7 +47,7 @@ def test_diffusers():

     sub_model_zoo = model_zoo.get_sub_registry('diffusers')

-    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
         data = data_gen_fn()
         trace_and_compare(model_fn, data, output_transform_fn)
         torch.cuda.synchronize()
@@ -60,7 +60,7 @@ def test_torch_diffusers():

     sub_model_zoo = model_zoo.get_sub_registry('diffusers')

-    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
         data = data_gen_fn()
         model = model_fn()
         output = model(**data)
@@ -56,7 +56,7 @@ def test_timm_models():

     sub_model_zoo = model_zoo.get_sub_registry('timm')

-    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
         data = data_gen_fn()
         if attribute is not None and attribute.has_control_flow:
             meta_args = {k: v.to('meta') for k, v in data.items()}
@@ -16,7 +16,7 @@ def test_torchaudio_models():

     sub_model_zoo = model_zoo.get_sub_registry('torchaudio')

-    for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
+    for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
         model = model_fn()
         trace_and_compare(model,
                           data_gen_fn,
@@ -60,7 +60,7 @@ def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn:


 def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False) -> None:
-    model_fn, data_gen_fn, output_transform_fn, model_attr = entry
+    model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry
     _MyTensor._pre_op_fn = lambda *args: set_seed(seed)
     LazyTensor._pre_op_fn = lambda *args: set_seed(seed)
     ctx = LazyInitContext(tensor_cls=_MyTensor)
@@ -78,7 +78,7 @@ def run_dist_lazy_init(subset, seed: int = 42):
         if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
             continue
         print_rank_0(name)
-        model_fn, data_gen_fn, output_transform_fn, model_attr = entry
+        model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry
         ctx = LazyInitContext(tensor_cls=_MyTensor)
         with ctx:
             model = model_fn()
@@ -0,0 +1,38 @@
+import copy
+
+from colossalai.shardformer import ShardConfig, ShardFormer
+
+
+def build_model(world_size, model_fn):
+    # create new model
+    org_model = model_fn().cuda()
+
+    # shard model
+    shard_config = ShardConfig(tensor_parallel_size=world_size)
+    model_copy = copy.deepcopy(org_model)
+    shard_former = ShardFormer(shard_config=shard_config)
+    shard_former.init_distributed()
+    sharded_model = shard_former.shard_model(model_copy)
+
+    return org_model, sharded_model
+
+
+def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+    # prepare input
+    data = data_gen_fn()
+    data = {k: v.cuda() for k, v in data.items()}
+
+    # switch to train mode
+    original_model.train()
+    sharded_model.train()
+
+    # run forward
+    org_output = original_model(**data)
+    org_output = output_transform_fn(org_output)
+    org_loss = loss_fn(org_output)
+
+    shard_output = sharded_model(**data)
+    shard_output = output_transform_fn(shard_output)
+    shard_loss = loss_fn(shard_output)
+
+    return org_output, org_loss, shard_output, shard_loss
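The new helper file above is what both reworked tests import. A minimal sketch of the intended flow, assuming `colossalai.launch` has already initialised the process group and using the `transformers_llama` sub-registry defined earlier in this commit; the world size of 2 is an assumption for illustration:

```python
# Sketch of how a sharded-vs-original test is assembled from the kit helpers.
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward

sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
    org_model, sharded_model = build_model(world_size=2, model_fn=model_fn)
    org_out, org_loss, shard_out, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
                                                           output_transform_fn, loss_fn)
    # compare forward outputs, then backprop both losses to compare gradients
    org_loss.backward()
    shard_loss.backward()
```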
@@ -1,64 +1,22 @@
-import copy
 import os
-import random

 import pytest
 import torch
-from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaTokenizerFast

 import colossalai
 from colossalai.logging import disable_existing_loggers
 from colossalai.shardformer import ShardConfig, ShardFormer
 from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import build_model, run_forward

 os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
-
-
-def build_model(world_size, model_fn):
-    # create new model
-    config = LlamaConfig(num_hidden_layers=4,
-                         hidden_size=128,
-                         intermediate_size=256,
-                         num_attention_heads=4,
-                         max_position_embeddings=128)
-    org_model = model_fn(config).cuda()
-
-    # shard model
-    shard_config = ShardConfig(tensor_parallel_size=world_size)
-    model_copy = copy.deepcopy(org_model)
-    shard_former = ShardFormer(shard_config=shard_config)
-    shard_former.init_distributed()
-    sharded_model = shard_former.shard_model(model_copy)
-
-    return org_model, sharded_model


-def check_forward_backward(org_model, sharded_model):
-    # prepare input
-    input = 'Hello, my dog is cute'
-    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
-    del tokenized_input["token_type_ids"]
-    del tokenized_input["attention_mask"]
-
-    # switch to train mode
-    org_model.train()
-    sharded_model.train()
-
-    if isinstance(org_model, (LlamaModel, LlamaForSequenceClassification)):
-        org_output = org_model(**tokenized_input)
-        org_loss = org_output.last_hidden_state.mean()
-        shard_output = sharded_model(**tokenized_input)
-        shard_loss = shard_output.last_hidden_state.mean()
-    elif isinstance(org_model, LlamaForCausalLM):
-        labels = tokenized_input['input_ids'].clone()
-        labels[labels == tokenizer.pad_token_id] = -100
-        tokenized_input['labels'] = labels
-        org_output = org_model(**tokenized_input)
-        org_loss = org_output.loss
-        shard_output = sharded_model(**tokenized_input)
-        shard_loss = shard_output.loss
+def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
+                                                                 output_transform_fn, loss_fn)

+    # forward check
     assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4)

     # run backward
|
|||
shard_loss.backward()
|
||||
|
||||
# check grad
|
||||
if isinstance(org_model, LlamaModel):
|
||||
llama_model = org_model
|
||||
shard_llama_model = sharded_model
|
||||
else:
|
||||
if hasattr(org_model, 'model'):
|
||||
llama_model = org_model.model
|
||||
shard_llama_model = sharded_model.model
|
||||
else:
|
||||
llama_model = org_model
|
||||
shard_llama_model = sharded_model
|
||||
|
||||
org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad
|
||||
shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad
|
||||
|
@ -89,17 +47,11 @@ def check_llama(rank, world_size, port):
|
|||
disable_existing_loggers()
|
||||
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
model_list = [
|
||||
LlamaModel,
|
||||
# LlamaForCausalLM,
|
||||
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
|
||||
|
||||
# TODO: do not work yet
|
||||
# LlamaForSequenceClassification
|
||||
]
|
||||
|
||||
for model_fn in model_list:
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
org_model, sharded_model = build_model(world_size, model_fn)
|
||||
check_forward_backward(org_model, sharded_model)
|
||||
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
|
|
@@ -1,64 +1,20 @@
-import copy
-import os

 import pytest
 import torch
-from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Tokenizer, T5TokenizerFast

 import colossalai
 from colossalai.logging import disable_existing_loggers
-from colossalai.shardformer.shard import ShardConfig, ShardFormer
 from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
-
-os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
-CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
-tokenizer = T5Tokenizer.from_pretrained("t5-small")
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import build_model, run_forward


-def build_model(world_size, model_fn):
-    config = T5Config(decoder_start_token_id=0)
-    config.dropout_rate = 0
-    org_model = model_fn(config=config).to('cuda')
-    shard_config = ShardConfig(tensor_parallel_size=world_size)
-
-    # shard model
-    shard_config = ShardConfig(tensor_parallel_size=world_size)
-    model_copy = copy.deepcopy(org_model)
-    shard_former = ShardFormer(shard_config=shard_config)
-    shard_former.init_distributed()
-    sharded_model = shard_former.shard_model(model_copy)
-
-    return org_model, sharded_model
-
-
-def check_forward_backward(org_model, sharded_model):
-    # prepare input
-    input_ids = tokenizer("translate English to German: The house is wonderful.",
-                          return_tensors="pt").input_ids.to('cuda')
-    labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids.to('cuda')
-
-    # switch to train mode
-    org_model.train()
-    sharded_model.train()
-
-    if isinstance(org_model, T5ForConditionalGeneration):
-        org_output = org_model(input_ids=input_ids, labels=labels)
-        org_loss = org_output.loss
-        shard_output = sharded_model(input_ids=input_ids, labels=labels)
-        shard_loss = shard_output.loss
-    elif isinstance(org_model, T5Model):
-        decoder_input_ids = org_model._shift_right(input_ids)
-        org_output = org_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
-        org_loss = org_output.last_hidden_state.mean()
-        shard_output = sharded_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
-        shard_loss = shard_output.last_hidden_state.mean()
-    elif isinstance(org_model, T5EncoderModel):
-        org_output = org_model(input_ids=input_ids)
-        org_loss = org_output.last_hidden_state.mean()
-        shard_output = sharded_model(input_ids=input_ids)
-        shard_loss = shard_output.last_hidden_state.mean()
-
-    # key is sharded, so we ignore
+def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
+    # check forward
+    # the value "past_key_values" is sharded, so we ignore
+    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
+                                                                 output_transform_fn, loss_fn)
     assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'])

     # do backward
@@ -81,18 +37,15 @@ def check_forward_backward(org_model, sharded_model):

 def check_t5(rank, world_size, port):
     disable_existing_loggers()
-    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

-    model_fn_list = [
-        T5Model,
-        T5ForConditionalGeneration,
-        T5EncoderModel,
-    ]
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')

-    for model_fn in model_fn_list:
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
         org_model, sharded_model = build_model(world_size, model_fn)
-        check_forward_backward(org_model, sharded_model)
-        torch.cuda.empty_cache()
+        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
+
+    torch.cuda.empty_cache()


 @pytest.mark.dist