[tutorial] removed huggingface model warning (#1925)

Frank Lee 2022-11-12 23:12:18 +08:00 committed by GitHub
parent d43a671ad6
commit abf4c27f6a
1 changed file with 13 additions and 13 deletions

@@ -30,24 +30,13 @@ from itertools import chain
 import datasets
 import torch
 import torch.distributed as dist
+import transformers
 from accelerate.utils import set_seed
 from context import barrier_context
 from datasets import load_dataset
 from packaging import version
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
-
-import colossalai
-import transformers
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.nn.parallel import ZeroDDP
-from colossalai.tensor import ProcessGroup
-from colossalai.utils import get_current_device, get_dataloader
-from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,
@@ -61,6 +50,17 @@ from transformers import (
 )
 from transformers.utils.versions import require_version
 
+import colossalai
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.parallel import ZeroDDP
+from colossalai.tensor import ProcessGroup
+from colossalai.utils import get_current_device, get_dataloader
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.zero import ZeroOptimizer
+
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
 
 MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
@@ -544,7 +544,7 @@ def main():
     model.train()
     for step, batch in enumerate(train_dataloader):
         batch = {k: v.cuda() for k, v in batch.items()}
-        outputs = model(**batch)
+        outputs = model(use_cache=False, **batch)
         loss = outputs['loss']
         optimizer.backward(loss)
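
Context for the use_cache=False change: when a HuggingFace causal LM is trained with gradient checkpointing enabled, transformers warns that `use_cache=True` is incompatible with gradient checkpointing and force-disables the KV cache on every forward pass. Passing use_cache=False explicitly silences that warning, which is presumably the "huggingface model warning" this commit removes. A minimal sketch of the pattern, assuming a GPT-style model; the "gpt2" checkpoint and the tiny batch here are illustrative, not taken from the tutorial script:

    # Sketch only: "gpt2" stands in for whatever model the tutorial loads.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    model.gradient_checkpointing_enable()  # checkpointing is what triggers the warning
    model.train()

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    batch = tokenizer("ColossalAI tutorial", return_tensors="pt")
    batch["labels"] = batch["input_ids"].clone()

    # Disabling the decoder KV cache up front avoids the
    # "`use_cache=True` is incompatible with gradient checkpointing" warning;
    # the cache is only useful for generation, not training.
    outputs = model(use_cache=False, **batch)
    outputs["loss"].backward()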