mirror of https://github.com/hpcaitech/ColossalAI
parent 6b30dfb7ce
commit c1c672d0f0

Binary file not shown.

@@ -770,6 +770,7 @@ class Embedding1D(ParallelLayer):
                  embedding_dim: int,
                  padding_idx: int = None,
                  dtype: torch.dtype = None,
+                 gather_output: bool = True,
                  weight_initializer: Callable = init.normal_(),
                  *args,
                  **kwargs):

@@ -782,6 +783,7 @@ class Embedding1D(ParallelLayer):
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
+        self.gather_output = gather_output

         self.weight = Parameter(
             torch.empty((num_embeddings, embed_dim_per_partition), device=get_current_device(), dtype=dtype))

@@ -832,8 +834,10 @@ class Embedding1D(ParallelLayer):
     def forward(self, input_: Tensor) -> Tensor:

         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
-        output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
+        if self.gather_output:
+            output = gather_forward_split_backward(output_parallel, ParallelMode.PARALLEL_1D, dim=-1)
+        else:
+            output = output_parallel

         return output
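
For orientation, a minimal single-process sketch (not ColossalAI code) of what the new gather_output flag controls: each rank holds a column shard of the embedding table, and gathering concatenates the per-rank partial outputs along the last dimension.

    # Minimal sketch, assuming 2-way column sharding of the embedding table.
    import torch
    import torch.nn.functional as F

    num_embeddings, embedding_dim, world_size = 10, 8, 2
    full_weight = torch.randn(num_embeddings, embedding_dim)
    shards = full_weight.chunk(world_size, dim=-1)        # each "rank" holds a (10, 4) shard

    tokens = torch.tensor([[1, 3, 5]])
    partials = [F.embedding(tokens, w) for w in shards]   # gather_output=False: per-rank (1, 3, 4)
    gathered = torch.cat(partials, dim=-1)                # gather_output=True: full (1, 3, 8)
    assert torch.equal(gathered, F.embedding(tokens, full_weight))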

@@ -43,6 +43,15 @@ def build_policies():

     from .gpt2 import GPT2LMHeadModelPolicy
     auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy

+    from .t5 import T5ForConditionalGenerationPolicy, T5EncoderModelPolicy, T5ModelPolicy
+    from transformers import T5ForConditionalGeneration, T5EncoderModel, T5Model
+    t5 = {
+        T5ForConditionalGeneration: T5ForConditionalGenerationPolicy,
+        T5EncoderModel: T5EncoderModelPolicy,
+        T5Model: T5ModelPolicy,
+    }
+    auto_policy_dict.update(t5)
+
     return auto_policy_dict
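
For orientation, a hedged sketch of how the mapping returned by build_policies() could be consulted to pick a policy for a model instance; the lookup-by-class pattern is inferred from the dict keys above and is not code from this commit, and it only runs inside the autopolicy module where build_policies is defined.

    from transformers import T5EncoderModel

    auto_policy_dict = build_policies()              # the function extended in this hunk
    model = T5EncoderModel.from_pretrained("t5-small")
    policy_cls = auto_policy_dict[model.__class__]   # -> T5EncoderModelPolicy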

@@ -80,6 +80,18 @@ class Dropout_Layer(Layer):
     p: str = None


+@dataclass
+class Embedding_Layer(Layer):
+    r"""
+    Class for the column-sharded embedding layer in tensor parallel
+
+    Args:
+        weight (str): The weight suffix of the layer
+    """
+    weight: str = None
+    gather_output: bool = True
+
+
 class Policy():
     r"""
     The base class for all the policies
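
As a quick illustration of the new dataclass, a hedged sketch of how a policy might describe an embedding to shard. The suffix "shared" is only an illustrative attribute name; the fields mirror the Embedding_Layer usage in the T5 policy added below.

    import colossalai.shardformer.layer.layers as col_nn

    example_embedding = Embedding_Layer(
        suffix="shared",                   # illustrative attribute path of the embedding
        weight="weight",                   # name of the weight tensor to shard
        replace_layer=col_nn.Embedding1D,  # parallel layer substituted by the sharder
        gather_output=False,               # keep the column-sharded output on each rank
    )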

@@ -0,0 +1,159 @@
+from typing import Dict
+
+import torch.nn as nn
+from torch.nn import Embedding
+from transformers.models.t5.modeling_t5 import (
+    T5Attention,
+    T5Block,
+    T5DenseActDense,
+    T5DenseGatedActDense,
+    T5LayerCrossAttention,
+    T5LayerFF,
+    T5LayerSelfAttention,
+    T5Model,
+    T5Stack,
+)
+
+import colossalai.shardformer.layer.layers as col_nn
+
+from .basepolicy import Argument, Col_Layer, Dropout_Layer, Embedding_Layer, Policy, Row_Layer
+
+
+class T5ModelPolicy(Policy):
+
+    @staticmethod
+    def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
+        print('config heads', config.num_heads)
+        return {
+            T5Stack:
+                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.embedding]),
+            T5Block:
+                Argument(attr_dict={}, param_funcs=[]),
+            T5LayerSelfAttention:
+                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
+            T5LayerCrossAttention:
+                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
+            T5Attention:
+                Argument(attr_dict={
+                    "d_model": config.d_model // world_size,
+                    "n_heads": config.num_heads // world_size,
+                    "inner_dim": config.num_heads * config.d_kv // world_size,
+                },
+                         param_funcs=[T5ModelPolicy.attn_layer]),
+            T5LayerFF:
+                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]),
+            T5DenseGatedActDense:
+                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_gated_layer]),
+            T5DenseActDense:
+                Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_act_layer]),
+        }
+
+    @staticmethod
+    def dense_gated_layer():
+        return [
+            Col_Layer(
+                suffix="wi_0",
+                weight="weight",
+                replace_layer=col_nn.Linear1D_Col,
+            ),
+            Row_Layer(
+                suffix="wi_1",
+                weight="weight",
+                replace_layer=col_nn.Linear1D_Row,
+            ),
+            Col_Layer(suffix="wo", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)
+        ]
+
+    @staticmethod
+    def dense_act_layer():
+        return [
+            Col_Layer(
+                suffix="wi",
+                weight="weight",
+                replace_layer=col_nn.Linear1D_Col,
+            ),
+            Row_Layer(
+                suffix="wo",
+                weight="weight",
+                replace_layer=col_nn.Linear1D_Row,
+            )
+        ]
+
+    @staticmethod
+    def attn_layer():
+        return [
+            Col_Layer(
+                suffix="q",
+                weight="weight",
+                bias="bias",
+                replace_layer=col_nn.Linear1D_Col,
+            ),
+            Col_Layer(
+                suffix="k",
+                weight="weight",
+                bias="bias",
+                replace_layer=col_nn.Linear1D_Col,
+            ),
+            Col_Layer(
+                suffix="v",
+                weight="weight",
+                bias="bias",
+                replace_layer=col_nn.Linear1D_Col,
+            ),
+            Row_Layer(
+                suffix="o",
+                weight="weight",
+                bias="bias",
+                replace_layer=col_nn.Linear1D_Row,
+            ),
+        ]
+
+    @staticmethod
+    def dropout():
+        return [Dropout_Layer(
+            suffix="dropout",
+            p="p",
+            replace_layer=col_nn.Dropout1D,
+        )]
+
+    @staticmethod
+    def embedding():
+        return [
+            Embedding_Layer(
+                suffix="block[0].layer[0].SelfAttention.relative_attention_bias",
+                weight="weight",
+                replace_layer=col_nn.Embedding1D,
+                gather_output=False,
+            )
+        ]
+
+
+from transformers import T5ForConditionalGeneration
+
+
+class T5ForConditionalGenerationPolicy(T5ModelPolicy):
+
+    @staticmethod
+    def argument_policy(config, world_size):
+        base_argument = T5ModelPolicy.argument_policy(config, world_size)
+        argument = {
+            T5ForConditionalGeneration: Argument(attr_dict={}, param_funcs=[T5ForConditionalGenerationPolicy.lm_head])
+        }
+        argument.update(base_argument)
+        return argument
+
+    @staticmethod
+    def lm_head():
+        return [Col_Layer(
+            suffix="lm_head",
+            weight="weight",
+            replace_layer=col_nn.Linear1D_Col,
+            gather_output=True,
+        )]
+
+
+from transformers import T5EncoderModel
+
+
+class T5EncoderModelPolicy(T5ModelPolicy):
+    pass
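
A worked example of the T5Attention attr_dict above, using t5-small's hyperparameters (d_model=512, num_heads=8, d_kv=64) and world_size=2; the numbers come from the published t5-small config, not from this commit.

    d_model, num_heads, d_kv, world_size = 512, 8, 64, 2
    print(d_model // world_size)            # 256 -> per-rank d_model
    print(num_heads // world_size)          # 4   -> per-rank attention heads
    print(num_heads * d_kv // world_size)   # 256 -> per-rank inner projection size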

@@ -5,7 +5,7 @@ import torch.nn as nn
 from transformers.pytorch_utils import Conv1D

 from ..policies.autopolicy import get_autopolicy
-from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer
+from ..policies.basepolicy import Col_Layer, Dropout_Layer, Policy, Row_Layer, Embedding_Layer
 from ..utils.utils import getattr_, hasattr_, setattr_
 from .shard_config import ShardConfig
 from .slicer import Slicer

@@ -155,11 +155,11 @@ class ModelSharder(object):
             assert suffix_layer is not None or ignore, f"Layer {org_layer.__class__.__qualname__} has no attribute {suffix}"
             if suffix_layer is None and ignore:
                 continue
-            if isinstance(policy_layer, (Col_Layer, Row_Layer)):
+            if isinstance(policy_layer, (Col_Layer, Row_Layer, Embedding_Layer)):
                 weight = None
                 bias = None
                 weight_attr = suffix + '.' + policy_layer.weight if policy_layer.weight is not None else None
-                bias_attr = suffix + '.' + policy_layer.bias if policy_layer.bias is not None else None
+                bias_attr = suffix + '.' + policy_layer.bias if hasattr(policy_layer, 'bias') and policy_layer.bias is not None else None

                 if weight_attr is not None:
                     if hasattr_(org_layer, weight_attr):

@@ -189,6 +189,11 @@ class ModelSharder(object):
                                               weight.shape[1],
                                               bias=False if bias is None else True,
                                               gather_output=gather_output)
+        elif replace_layer_cls.__name__ == "Embedding1D":
+            gather_output = policy_layer.gather_output
+            replace_layer = replace_layer_cls(weight.shape[0],
+                                              weight.shape[1],
+                                              gather_output=gather_output)
         elif replace_layer_cls.__name__ == "VocabParallelEmbedding1D":
             replace_layer = replace_layer_cls(weight.shape[0], weight.shape[1],
                                               getattr_(org_layer, f"{suffix}.padding_idx", ignore=True))

@@ -1,9 +1,9 @@
 import torch

-from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer
+from ..policies.basepolicy import Col_Layer, Dropout_Layer, Layer, Row_Layer, Embedding_Layer
 from .shard_config import ShardConfig

-dim_mapping = {Col_Layer: 0, Row_Layer: 1}
+dim_mapping = {Col_Layer: 0, Row_Layer: 1, Embedding_Layer: 1}


 class Slicer():

@@ -43,6 +43,8 @@ class Slicer():
                 bias = self.slice_tensor(bias, 0, True, n_cast)
         elif policy_layer_cls == Row_Layer:
             weight = self.slice_tensor(weight, dim, False, n_cast)
+        elif policy_layer_cls == Embedding_Layer:
+            weight = self.slice_tensor(weight, dim, False, n_cast)
         else:
             raise NotImplementedError(f"The policy layer class {policy_layer_cls} is not supported")
         if reversed:
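
For orientation, a single-process sketch of what dim_mapping = {..., Embedding_Layer: 1} means: the (num_embeddings, embedding_dim) weight is split along the embedding dimension, so each rank keeps all rows but only a slice of the columns. This is an illustration of the split, not the Slicer implementation.

    import torch

    world_size, rank = 2, 0
    weight = torch.randn(32, 128)                   # full embedding table
    shard = weight.chunk(world_size, dim=1)[rank]   # this rank's column shard
    print(shard.shape)                              # torch.Size([32, 64])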

@@ -1,3 +1,22 @@
+import re
+
+
+def get_obj_list_element(obj, a):
+    re_pattern = r'\[\d+\]'
+    prog = re.compile(re_pattern)
+    result = prog.search(a)
+    if result:
+        matched_brackets = result.group()
+        matched_index = matched_brackets.replace('[', '')
+        matched_index = matched_index.replace(']', '')
+        a_ = a.replace(matched_brackets, '')
+        container_obj = getattr(obj, a_)
+        obj = container_obj[int(matched_index)]
+    else:
+        obj = getattr(obj, a)
+    return obj
+
+
 def hasattr_(obj, attr: str):
     r"""
     Check whether the object has the multi sublevel attr

@@ -9,7 +28,7 @@ def hasattr_(obj, attr: str):
     attrs = attr.split('.')
     for a in attrs:
         try:
-            obj = getattr(obj, a)
+            obj = get_obj_list_element(obj, a)
         except AttributeError:
             return False
     return True

@@ -29,7 +48,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False):
     attrs = attr.split('.')
     for a in attrs[:-1]:
         try:
-            obj = getattr(obj, a)
+            obj = get_obj_list_element(obj, a)
         except AttributeError:
             if ignore:
                 return

@@ -50,7 +69,7 @@ def getattr_(obj, attr: str, ignore: bool = False):
     attrs = attr.split('.')
     for a in attrs:
         try:
-            obj = getattr(obj, a)
+            obj = get_obj_list_element(obj, a)
         except AttributeError:
             if ignore:
                 return None
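
With get_obj_list_element in place, these helpers accept paths that mix attribute names and list indices, which the T5 policy's "block[0].layer[0]..." suffix relies on. A small usage sketch, assuming getattr_ and hasattr_ from the module above are in scope:

    import torch.nn as nn

    m = nn.Module()
    m.block = nn.ModuleList([nn.Linear(4, 4), nn.Linear(8, 8)])
    print(getattr_(m, "block[1].weight").shape)   # torch.Size([8, 8])
    print(hasattr_(m, "blocks[0]"))               # False: "blocks" does not exist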

@@ -15,3 +15,4 @@ einops
 triton==2.0.0.dev20221202
 git+https://github.com/HazyResearch/flash-attention.git@c422fee3776eb3ea24e011ef641fd5fbeb212623#egg=flash_attn
 requests==2.27.1 # downgrade to avoid huggingface error https://github.com/huggingface/transformers/issues/17611
+SentencePiece

@@ -0,0 +1,99 @@
+import copy
+import os
+import random
+
+import pytest
+import torch
+from transformers import AutoTokenizer, BertConfig, BertForMaskedLM, T5Config, T5ForConditionalGeneration, T5Tokenizer
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.shard import ShardConfig, shard_model
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
+CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
+tokenizer = T5Tokenizer.from_pretrained("t5-small")
+
+
+def build_model(rank, world_size):
+    config = T5Config.from_pretrained("t5-small")
+    config.dropout_rate = 0
+    org_model = T5ForConditionalGeneration.from_pretrained("t5-small", config=config).to('cuda')
+
+    shardconfig = ShardConfig(
+        rank=rank,
+        world_size=world_size,
+        gather_output=True,
+    )
+
+    org_model_for_shard = copy.deepcopy(org_model)
+
+    sharded_model = shard_model(org_model_for_shard, shardconfig).to('cuda')
+
+    return org_model, sharded_model
+
+
+def check_forward(org_model, sharded_model):
+    input_ids = tokenizer("translate English to German: The house is wonderful.",
+                          return_tensors="pt").input_ids.to('cuda')
+    # original model
+    org_model.eval()
+    org_output = org_model.generate(input_ids)
+
+    # sharded model
+    sharded_model.eval()
+    shard_output = sharded_model.generate(input_ids)
+    assert torch.allclose(
+        org_output[0], shard_output[0],
+        atol=1e-5), f"shard model output is not equal to origin model output\n{org_output[0]}\n{shard_output[0]}"
+
+
+def check_backward(org_model, sharded_model):
+    # prepare input
+    input_ids = tokenizer("translate English to German: The house is wonderful.",
+                          return_tensors="pt").input_ids.to('cuda')
+    labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids.to('cuda')
+
+    # original model
+    org_model.train()
+    org_loss = org_model(input_ids=input_ids, labels=labels).loss
+    org_loss.backward()
+    org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
+
+    # sharded model
+    sharded_model.train()
+    shard_loss = sharded_model(input_ids=input_ids, labels=labels).loss
+    shard_loss.backward()
+    shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad
+
+    # gather the column-sharded q.weight gradient from both ranks before comparing
+    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
+    torch.distributed.all_gather(shard_grad_list, shard_grad)
+    all_shard_grad = torch.cat(shard_grad_list, dim=0)
+
+    assert torch.allclose(org_loss, shard_loss,
+                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
+    assert torch.allclose(org_grad, all_shard_grad,
+                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
+
+
+def check_t5(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    org_model, sharded_model = build_model(rank, world_size)
+    check_forward(org_model, sharded_model)
+    check_backward(org_model, sharded_model)
+
+    torch.cuda.empty_cache()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_t5():
+    spawn(check_t5, 2)
+
+
+if __name__ == "__main__":
+    test_t5()