|
|
|
@ -30,24 +30,13 @@ from itertools import chain
|
|
|
|
|
import datasets |
|
|
|
|
import torch |
|
|
|
|
import torch.distributed as dist |
|
|
|
|
import transformers |
|
|
|
|
from accelerate.utils import set_seed |
|
|
|
|
from context import barrier_context |
|
|
|
|
from datasets import load_dataset |
|
|
|
|
from packaging import version |
|
|
|
|
from torch.utils.data import DataLoader |
|
|
|
|
from tqdm.auto import tqdm |
|
|
|
|
|
|
|
|
|
import colossalai |
|
|
|
|
import transformers |
|
|
|
|
from colossalai.context import ParallelMode |
|
|
|
|
from colossalai.core import global_context as gpc |
|
|
|
|
from colossalai.logging import disable_existing_loggers, get_dist_logger |
|
|
|
|
from colossalai.nn.optimizer import HybridAdam |
|
|
|
|
from colossalai.nn.parallel import ZeroDDP |
|
|
|
|
from colossalai.tensor import ProcessGroup |
|
|
|
|
from colossalai.utils import get_current_device, get_dataloader |
|
|
|
|
from colossalai.utils.model.colo_init_context import ColoInitContext |
|
|
|
|
from colossalai.zero import ZeroOptimizer |
|
|
|
|
from transformers import ( |
|
|
|
|
CONFIG_MAPPING, |
|
|
|
|
MODEL_MAPPING, |
|
|
|
@ -61,6 +50,17 @@ from transformers import (
|
|
|
|
|
) |
|
|
|
|
from transformers.utils.versions import require_version |
|
|
|
|
|
|
|
|
|
import colossalai |
|
|
|
|
from colossalai.context import ParallelMode |
|
|
|
|
from colossalai.core import global_context as gpc |
|
|
|
|
from colossalai.logging import disable_existing_loggers, get_dist_logger |
|
|
|
|
from colossalai.nn.optimizer import HybridAdam |
|
|
|
|
from colossalai.nn.parallel import ZeroDDP |
|
|
|
|
from colossalai.tensor import ProcessGroup |
|
|
|
|
from colossalai.utils import get_current_device, get_dataloader |
|
|
|
|
from colossalai.utils.model.colo_init_context import ColoInitContext |
|
|
|
|
from colossalai.zero import ZeroOptimizer |
|
|
|
|
|
|
|
|
|
require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/language-modeling/requirements.txt") |
|
|
|
|
|
|
|
|
|
MODEL_CONFIG_CLASSES = list(MODEL_MAPPING.keys()) |
|
|
|
@ -544,7 +544,7 @@ def main():
|
|
|
|
|
model.train() |
|
|
|
|
for step, batch in enumerate(train_dataloader): |
|
|
|
|
batch = {k: v.cuda() for k, v in batch.items()} |
|
|
|
|
outputs = model(**batch) |
|
|
|
|
outputs = model(use_cache=False, **batch) |
|
|
|
|
loss = outputs['loss'] |
|
|
|
|
optimizer.backward(loss) |
|
|
|
|
|
|
|
|
|