From 21aa5de00b6138c019fae5f58024f2aff6f97a3a Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Fri, 8 Dec 2023 11:10:51 +0800 Subject: [PATCH] [gemini] hotfix NaN loss while using Gemini + tensor_parallel (#5150) * fix aaa fix fix fix * fix * fix * test ci * fix ci fix --- colossalai/booster/plugin/gemini_plugin.py | 54 ++++++++++++++++++++++ examples/language/llama2/benchmark.py | 5 +- tests/kit/model_zoo/transformers/gptj.py | 2 +- 3 files changed, 59 insertions(+), 2 deletions(-) diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 261080dc9..6622b6dc1 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -1,9 +1,11 @@ import gc import logging import os +import random from pathlib import Path from typing import Callable, Iterator, List, Optional, Tuple +import numpy as np import torch import torch.distributed as dist import torch.nn as nn @@ -11,6 +13,7 @@ from torch.distributed.distributed_c10d import _get_default_group from torch.optim import Optimizer from torch.optim.lr_scheduler import _LRScheduler as LRScheduler from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO from colossalai.checkpoint_io.utils import ( @@ -448,6 +451,57 @@ class GeminiPlugin(DPPluginBase): def supported_devices(self) -> List[str]: return ["cuda", "npu"] + + def prepare_dataloader( + self, dataset, batch_size, shuffle=False, seed=1024, drop_last=False, pin_memory=False, num_workers=0, **kwargs + ): + r""" + Prepare a dataloader for distributed training. The dataloader will be wrapped by + `torch.utils.data.DataLoader` and `torch.utils.data.DistributedSampler`. + + + Args: + dataset (`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. + """ + _kwargs = kwargs.copy() + zero_world_size = self.pg_mesh.size(ZERO_AXIS) + extra_dp_world_size = self.pg_mesh.size(DP_AXIS) + zero_rank = self.pg_mesh.coordinate(ZERO_AXIS) + extra_dp_rank = self.pg_mesh.coordinate(DP_AXIS) + sampler = DistributedSampler( + dataset, num_replicas=zero_world_size * extra_dp_world_size, rank=zero_rank * extra_dp_world_size + extra_dp_rank, shuffle=shuffle + ) + + # Deterministic dataloader + def seed_worker(worker_id): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + return DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) def configure( self, diff --git a/examples/language/llama2/benchmark.py b/examples/language/llama2/benchmark.py index d7a79a022..daf7d2fd4 100644 --- a/examples/language/llama2/benchmark.py +++ b/examples/language/llama2/benchmark.py @@ -72,6 +72,7 @@ def main(): parser.add_argument("--offload_optim_frac", type=float, default=0.0, help="Offload optim fraction. Only for gemini") parser.add_argument("--offload_param_frac", type=float, default=0.0, help="Offload param fraction. Only for gemini") parser.add_argument("--tp", type=int, default=1, help="Tensor parallel size") + parser.add_argument("--extra_dp", type=int, default=1, help="Extra data parallel size, used for Gemini") parser.add_argument("--pp", type=int, default=1, help="Pipeline parallel size") parser.add_argument("--mbs", type=int, default=1) parser.add_argument("--zero", type=int, default=0) @@ -93,9 +94,11 @@ def main(): shard_param_frac=args.shard_param_frac, offload_optim_frac=args.offload_optim_frac, offload_param_frac=args.offload_param_frac, + tp_size=args.tp, + extra_dp_size=args.extra_dp, ) elif args.plugin == "gemini_auto": - plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio) + plugin = GeminiPlugin(placement_policy="auto", precision="bf16", warmup_non_model_data_ratio=args.warmup_ratio, tp_size=args.tp, extra_dp_size=args.extra_dp) elif args.plugin == "fsdp": if use_empty_init: plugin = TorchFSDPPlugin( diff --git a/tests/kit/model_zoo/transformers/gptj.py b/tests/kit/model_zoo/transformers/gptj.py index 263978512..9eefbb43d 100644 --- a/tests/kit/model_zoo/transformers/gptj.py +++ b/tests/kit/model_zoo/transformers/gptj.py @@ -61,7 +61,7 @@ loss_fn = lambda x: x.loss config = transformers.GPTJConfig( n_layer=2, - n_head=16, + n_head=4, vocab_size=50258, attn_pdrop=0, embd_pdrop=0,