# ColossalAI/colossalai/shardformer/test/test.py
# (page metadata from the original listing: 38 lines, 1.2 KiB, Python)

import argparse
import inspect
import os

import torch
import torch.nn as nn
from transformers import AutoTokenizer
from transformers import BertForMaskedLM

import colossalai
from colossalai.logging import get_dist_logger
from colossalai.shardformer.shard.shardconfig import ShardConfig
from colossalai.shardformer.shard.shardmodel import ShardModel
from colossalai.utils import get_current_device, print_rank_0
# Tokenizer used by inference(); it is loaded once, when this module is
# imported. The checkpoint name matches the model loaded in __main__.
_TOKENIZER_CHECKPOINT = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_CHECKPOINT)
def get_args():
    """Parse the command line with ColossalAI's default argument parser.

    Returns:
        The parsed argument namespace. The script entry point reads
        ``args.config`` from it.
    """
    default_parser = colossalai.get_default_parser()
    parsed_args = default_parser.parse_args()
    return parsed_args
def inference(model: nn.Module, text: str = "Hello, my dog is cute", device: str = "cuda"):
    """Run a single forward pass of ``model`` on ``text`` and print the result.

    The module-level ``tokenizer`` encodes ``text`` into PyTorch tensors.
    Both the encoded inputs and ``model`` are moved to ``device``, and the
    model is run with gradient tracking disabled.

    Args:
        model: Model to evaluate. ``nn.Module.to`` moves it to ``device``
            in place.
        text: Input sentence to tokenize.
        device: Target device for the inputs and the model.

    Returns:
        Whatever ``model(**inputs)`` returns. The same object is also printed.
    """
    inputs = tokenizer(text, return_tensors="pt")
    # Rebind rather than rely on ``.to`` mutating in place.
    inputs = inputs.to(device)
    model.to(device)
    # Pure inference: skip building the autograd graph to save memory.
    with torch.no_grad():
        outputs = model(**inputs)
    print(outputs)
    return outputs
if __name__ == "__main__":
    args = get_args()
    # Sets up the distributed environment. It presumably reads the torchrun-style
    # RANK / LOCAL_RANK / WORLD_SIZE env vars (TODO confirm against ColossalAI).
    colossalai.launch_from_torch(config=args.config)
    model = BertForMaskedLM.from_pretrained("bert-base-uncased")
    # Shard by the *global* rank so it is consistent with the global WORLD_SIZE.
    # The CUDA device index is only the node-local rank, so it would repeat
    # across nodes. A device string with no ":N" suffix would also make int()
    # fail.
    shard_config = ShardConfig(
        rank=int(os.environ["RANK"]),
        world_size=int(os.environ["WORLD_SIZE"]),
    )
    shardmodel = ShardModel(model, shard_config)
    inference(shardmodel.model)