|
|
|
@@ -30,24 +30,13 @@ from itertools import chain
|
|
|
|
|
import datasets
|
|
|
|
|
import torch
|
|
|
|
|
import torch.distributed as dist
|
|
|
|
|
import transformers
|
|
|
|
|
from accelerate.utils import set_seed
|
|
|
|
|
from context import barrier_context
|
|
|
|
|
from datasets import load_dataset
|
|
|
|
|
from packaging import version
|
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
|
from tqdm.auto import tqdm
|
|
|
|
|
|
|
|
|
|
import colossalai
|
|
|
|
|
import transformers
|
|
|
|
|
from colossalai.context import ParallelMode
|
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
|
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
|
|
|
|
from colossalai.nn.optimizer import HybridAdam
|
|
|
|
|
from colossalai.nn.parallel import ZeroDDP
|
|
|
|
|
from colossalai.tensor import ProcessGroup
|
|
|
|
|
from colossalai.utils import get_current_device, get_dataloader
|
|
|
|
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
|
|
|
|
from colossalai.zero import ZeroOptimizer
|
|
|
|
|
from transformers import (
|
|
|
|
|
CONFIG_MAPPING,
|
|
|
|
|
MODEL_MAPPING,
|
|
|
|
@ -61,6 +50,17 @@ from transformers import (
|
|
|
|
|
)
|
|
|
|
|
from transformers.utils.versions import require_version
|
|
|
|
|
|
|
|
|
|
import colossalai
|
|
|
|
|
from colossalai.context import ParallelMode
|
|
|
|
|
from colossalai.core import global_context as gpc
|
|
|
|
|
from colossalai.logging import disable_existing_loggers, get_dist_logger
|
|
|
|
|
from colossalai.nn.optimizer import HybridAdam
|
|
|
|
|
from colossalai.nn.parallel import ZeroDDP
|
|
|
|
|
from colossalai.tensor import ProcessGroup
|
|
|
|
|
from colossalai.utils import get_current_device, get_dataloader
|
|
|
|
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
|
|
|
|
from colossalai.zero import ZeroOptimizer
|
|
|
|
|
|
|
|
|
|
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt")
|
|
|
|
|
|
|
|
|
|
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys())
|
|
|
|
@@ -544,7 +544,7 @@ def main():
|
|
|
|
|
model.train()
|
|
|
|
|
for step, batch in enumerate(train_dataloader):
|
|
|
|
|
batch = {k: v.cuda() for k, v in batch.items()}
|
|
|
|
|
outputs = model(**batch)
|
|
|
|
|
outputs = model(use_cache=False, **batch)
|
|
|
|
|
loss = outputs['loss']
|
|
|
|
|
optimizer.backward(loss)
|
|
|
|
|
|
|
|
|
|