mirror of https://github.com/hpcaitech/ColossalAI
[gemini] hotfix NaN loss while using Gemini + tensor_parallel (#5150)
* fix aaa fix fix fix * fix * fix * test ci * fix ci fixpull/5177/head
parent
b397104438
commit
21aa5de00b
|
@ -1,9 +1,11 @@
|
|||
import gc
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterator, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
@ -11,6 +13,7 @@ from torch.distributed.distributed_c10d import _get_default_group
|
|||
from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
|
||||
from colossalai.checkpoint_io.utils import (
|
||||
|
@ -448,6 +451,57 @@ class GeminiPlugin(DPPluginBase):
|
|||
|
||||
def supported_devices(self) -> List[str]:
|
||||
return ["cuda", "npu"]
|
||||
|
||||
def prepare_dataloader(
|
||||
self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs
|
||||
):
|
||||
r"""
|
||||
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
||||
`torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`.
|
||||
|
||||
|
||||
Args:
|
||||
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
|
||||
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
|
||||
seed (int, optional): Random worker seed for sampling, defaults to 1024.
|
||||
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
|
||||
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
|
||||
is not divisible by the batch size. If False and the size of dataset is not divisible by
|
||||
the batch size, then the last batch will be smaller, defaults to False.
|
||||
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
|
||||
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
|
||||
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
|
||||
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
|
||||
|
||||
Returns:
|
||||
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
||||
"""
|
||||
_kwargs = kwargs.copy()
|
||||
zero_world_size = self.pg_mesh.size(ZERO_AXIS)
|
||||
extra_dp_world_size = self.pg_mesh.size(DP_AXIS)
|
||||
zero_rank = self.pg_mesh.coordinate(ZERO_AXIS)
|
||||
extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS)
|
||||
sampler = DistributedSampler(
|
||||
dataset, num_replicas=zero_world_size * extra_dp_world_size, rank=zero_rank * extra_dp_world_size + extra_dp_rank, shuffle=shuffle
|
||||
)
|
||||
|
||||
# Deterministic dataloader
|
||||
def seed_worker(worker_id):
|
||||
worker_seed = seed
|
||||
np.random.seed(worker_seed)
|
||||
torch.manual_seed(worker_seed)
|
||||
random.seed(worker_seed)
|
||||
|
||||
return DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
sampler=sampler,
|
||||
worker_init_fn=seed_worker,
|
||||
drop_last=drop_last,
|
||||
pin_memory=pin_memory,
|
||||
num_workers=num_workers,
|
||||
**_kwargs,
|
||||
)
|
||||
|
||||
def configure(
|
||||
self,
|
||||
|
|
|
@ -72,6 +72,7 @@ def main():
|
|||
parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini")
|
||||
parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini")
|
||||
parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size")
|
||||
parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini")
|
||||
parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size")
|
||||
parser.add_argument("--mbs", type=int, default=1)
|
||||
parser.add_argument("--zero", type=int, default=0)
|
||||
|
@ -93,9 +94,11 @@ def main():
|
|||
shard_param_frac=args.shard_param_frac,
|
||||
offload_optim_frac=args.offload_optim_frac,
|
||||
offload_param_frac=args.offload_param_frac,
|
||||
tp_size=args.tp,
|
||||
extra_dp_size=args.extra_dp,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio)
|
||||
plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio, tp_size=args.tp, extra_dp_size=args.extra_dp)
|
||||
elif args.plugin == "fsdp":
|
||||
if use_empty_init:
|
||||
plugin = TorchFSDPPlugin(
|
||||
|
|
|
@ -61,7 +61,7 @@ loss_fn = lambda x: x.loss
|
|||
|
||||
config = transformers.GPTJConfig(
|
||||
n_layer=2,
|
||||
n_head=16,
|
||||
n_head=4,
|
||||
vocab_size=50258,
|
||||
attn_pdrop=0,
|
||||
embd_pdrop=0,
|
||||
|
|
Loading…
Reference in New Issue