mirror of https://github.com/hpcaitech/ColossalAI
[tutorial] removed huggingface model warning (#1925)
parent d43a671ad6
commit abf4c27f6a
@@ -30,24 +30,13 @@ from itertools import chain
 import datasets
 import torch
 import torch.distributed as dist
+import transformers
 from accelerate.utils import set_seed
 from context import barrier_context
 from datasets import load_dataset
 from packaging import version
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
-
-import colossalai
-import transformers
-from colossalai.context import ParallelMode
-from colossalai.core import global_context as gpc
-from colossalai.logging import disable_existing_loggers, get_dist_logger
-from colossalai.nn.optimizer import HybridAdam
-from colossalai.nn.parallel import ZeroDDP
-from colossalai.tensor import ProcessGroup
-from colossalai.utils import get_current_device, get_dataloader
-from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.zero import ZeroOptimizer
 from transformers import (
     CONFIG_MAPPING,
     MODEL_MAPPING,
@@ -61,6 +50,17 @@ from transformers import (
 )
 from transformers.utils.versions import require_version
 
+import colossalai
+from colossalai.context import ParallelMode
+from colossalai.core import global_context as gpc
+from colossalai.logging import disable_existing_loggers, get_dist_logger
+from colossalai.nn.optimizer import HybridAdam
+from colossalai.nn.parallel import ZeroDDP
+from colossalai.tensor import ProcessGroup
+from colossalai.utils import get_current_device, get_dataloader
+from colossalai.utils.model.colo_init_context import ColoInitContext
+from colossalai.zero import ZeroOptimizer
+
 require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
 
 MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
@@ -544,7 +544,7 @@ def main():
         model.train()
         for step, batch in enumerate(train_dataloader):
             batch = {k: v.cuda() for k, v in batch.items()}
-            outputs = model(**batch)
+            outputs = model(use_cache=False, **batch)
             loss = outputs['loss']
             optimizer.backward(loss)
 
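The `use_cache=False` argument in the last hunk is what silences the Hugging Face warning named in the commit title: past key/value caches only help autoregressive generation, not loss computation, so disabling them during training avoids the warning (and a little memory). Below is a minimal, self-contained sketch of the same pattern outside the tutorial script; the "gpt2" checkpoint, the gradient-checkpointing call, and the exact warning text are illustrative assumptions, not taken from this diff.

from transformers import AutoModelForCausalLM, AutoTokenizer

# Standalone sketch: passing use_cache=False in the forward call keeps Hugging Face
# models from warning during training (typically the "`use_cache=True` is incompatible
# with gradient checkpointing" message). Model name and setup here are assumptions.
model = AutoModelForCausalLM.from_pretrained("gpt2")
model.gradient_checkpointing_enable()  # training setups like this tutorial often enable it
tokenizer = AutoTokenizer.from_pretrained("gpt2")

batch = tokenizer("ColossalAI tutorial example", return_tensors="pt")
batch["labels"] = batch["input_ids"].clone()  # causal LM loss needs labels

model.train()
# Without use_cache=False the model would build (or warn about) past key/value caches
# that are useless for training; disabling it keeps the log clean.
outputs = model(use_cache=False, **batch)
loss = outputs["loss"]
loss.backward()
print(float(loss))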