[tutorial] removed huggingface model warning (#1925)

pull/1927/head
Frank Lee 2 years ago committed by GitHub
parent d43a671ad6
commit abf4c27f6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -30,24 +30,13 @@ from itertools import chain
import datasets
import torch
import torch.distributed as dist
import transformers
from accelerate.utils import set_seed
from context import barrier_context
from datasets import load_dataset
from packaging import version
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import colossalai
import transformers
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device, get_dataloader
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
from transformers import (
CONFIG_MAPPING,
MODEL_MAPPING,
@ -61,6 +50,17 @@ from transformers import (
)
from transformers.utils.versions import require_version
import colossalai
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import ZeroDDP
from colossalai.tensor import ProcessGroup
from colossalai.utils import get_current_device, get_dataloader
from colossalai.utils.model.colo_init_context import ColoInitContext
from colossalai.zero import ZeroOptimizer
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
@ -544,7 +544,7 @@ def main():
model.train()
for step, batch in enumerate(train_dataloader):
batch = {k: v.cuda() for k, v in batch.items()}
outputs = model(**batch)
outputs = model(use_cache=False, **batch)
loss = outputs['loss']
optimizer.backward(loss)

Loading…
Cancel
Save