From abf4c27f6adc4b65914744a23ba23c4e60b2a722 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Sat, 12 Nov 2022 23:12:18 +0800
Subject: [PATCH] [tutorial] removed huggingface model warning (#1925)

---
 examples/tutorial/opt/opt/run_clm.py | 26 +++++++++++++-------------
 1 file changed, 13 insertions(+), 13 deletions(-)

diff --git a/examples/tutorial/opt/opt/run_clm.py b/examples/tutorial/opt/opt/run_clm.py
index 00e05459a..2b96642ae 100755
--- a/examples/tutorial/opt/opt/run_clm.py
+++ b/examples/tutorial/opt/opt/run_clm.py
@@ -30,24 +30,13 @@ from itertools import chain
 import datasets
 import torch
 import torch.distributed as dist
+import transformers
 from accelerate.utils import set_seed
 from context import barrier_context
 from datasets import load_dataset
 from packaging import version
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
-
-import colossalai
-import transformers
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.nn.parallel import ZeroDDP
-from colossalai.tensor import ProcessGroup
-from colossalai.utils import get_current_device, get_dataloader
-from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,
@@ -61,6 +50,17 @@ from transformers import (
 )
 from transformers.utils.versions import require_version
 
+import colossalai
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.parallel import ZeroDDP
+from colossalai.tensor import ProcessGroup
+from colossalai.utils import get_current_device, get_dataloader
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.zero import ZeroOptimizer
+
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
 
 MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
@@ -544,7 +544,7 @@ def main():
         model.train()
         for step, batch in enumerate(train_dataloader):
             batch = {k: v.cuda() for k, v in batch.items()}
-            outputs = model(**batch)
+            outputs = model(use_cache=False, **batch)
             loss = outputs['loss']
             optimizer.backward(loss)