2023-09-24 15:12:26 +00:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
"""
|
2024-02-05 08:33:18 +00:00
|
|
|
Continual Pre-training/Supervised fine-tuning of Colossal-LLaMA-2 developed by Colossal-AI Team
|
2023-09-24 15:12:26 +00:00
|
|
|
"""
|
|
|
|
|
|
|
|
import argparse
|
2024-01-09 02:20:05 +00:00
|
|
|
import json
|
2023-09-24 15:12:26 +00:00
|
|
|
import os
|
|
|
|
import resource
|
|
|
|
from contextlib import nullcontext
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
2024-04-23 05:54:05 +00:00
|
|
|
from colossal_llama.dataset.loader import (
|
2024-01-09 02:20:05 +00:00
|
|
|
DataCollatorForSupervisedDataset,
|
|
|
|
StatefulDistributedSampler,
|
|
|
|
load_tokenized_dataset,
|
|
|
|
)
|
2024-04-23 05:54:05 +00:00
|
|
|
from colossal_llama.utils.ckpt_io import load_checkpoint, save_checkpoint
|
|
|
|
from colossal_llama.utils.flash_attention_patch import replace_with_flash_attention
|
|
|
|
from colossal_llama.utils.froze import freeze_non_embeds_parameters
|
|
|
|
from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune
|
2023-09-24 15:12:26 +00:00
|
|
|
from torch.utils.tensorboard import SummaryWriter
|
2024-01-09 02:20:05 +00:00
|
|
|
from tqdm import tqdm
|
2024-04-23 05:54:05 +00:00
|
|
|
from transformers import AutoTokenizer, LlamaForCausalLM
|
2023-09-24 15:12:26 +00:00
|
|
|
|
|
|
|
import colossalai
|
2024-02-05 08:33:18 +00:00
|
|
|
from colossalai.accelerator import get_accelerator
|
2023-09-24 15:12:26 +00:00
|
|
|
from colossalai.booster import Booster
|
2024-01-09 02:20:05 +00:00
|
|
|
from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin
|
2023-09-24 15:12:26 +00:00
|
|
|
from colossalai.cluster import DistCoordinator
|
|
|
|
from colossalai.lazy import LazyInitContext
|
|
|
|
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
|
|
|
|
from colossalai.nn.optimizer import HybridAdam
|
2024-02-05 08:33:18 +00:00
|
|
|
from colossalai.utils import get_current_device
|
2023-09-24 15:12:26 +00:00
|
|
|
|
|
|
|
|
|
|
|
def get_model_numel(model: torch.nn.Module) -> int:
    """Return the total number of scalar elements in ``model``'s parameters.

    Every tensor yielded by ``model.parameters()`` is counted, regardless of
    whether it requires gradients. A module without parameters yields 0.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
|
|
|
|
|
|
|
|
|
|
|
|
def format_numel_str(numel: int, *, base: int = 1024) -> str:
    """Format an element/parameter count as a short human-readable string.

    Args:
        numel: Number of elements (e.g. model parameters) to format.
        base: Step between consecutive unit suffixes ("K", "M", "B").
            The default of 1024 keeps the original binary scaling, where
            "B" means 1024**3. Pass ``base=1000`` for decimal units, where
            "B" means one billion; that is the usual convention for model
            sizes. For example, ``format_numel_str(6_738_415_616)`` gives
            ``"6.28 B"``, while ``base=1000`` gives ``"6.74 B"``.

    Returns:
        The count scaled to the largest unit it reaches, with two decimals
        (e.g. ``"1.00 K"``). Counts below one "K" are returned as the
        plain integer (e.g. ``"42"``).

    Raises:
        ValueError: If ``base`` is not greater than 1.
    """
    if base <= 1:
        raise ValueError(f"base must be greater than 1, got {base}")
    # Try the largest unit first so each value gets the biggest suffix it reaches.
    for suffix, unit in (("B", base**3), ("M", base**2), ("K", base)):
        if numel >= unit:
            return f"{numel / unit:.2f} {suffix}"
    return f"{numel}"
|
|
|
|
|
|
|
|
|
|
|
|
def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """Average ``tensor`` over all ranks of the default process group, in place.

    The tensor is first summed across ranks with ``dist.all_reduce``. It is
    then divided by ``dist.get_world_size()``. Both steps write into the
    caller's storage, so the return value may be ignored. The object returned
    is ``tensor.data``, an alias of the same storage that is detached from
    autograd.
    """
    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
    # Divide through `.data` so the in-place division bypasses autograd tracking;
    # `div_` returns the very tensor it modified.
    return tensor.data.div_(dist.get_world_size())
|
|
|
|
|
|
|
|
|
|
|
|
def main() -> None:
|
|
|
|
# ==============================
|
|
|
|
# Parse Arguments
|
|
|
|
# ==============================
|
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
parser.add_argument(
|
|
|
|
"--pretrained",
|
|
|
|
type=str,
|
|
|
|
default=None,
|
|
|
|
help="Address of the pre-trained modeling",
|
|
|
|
)
|
|
|
|
parser.add_argument("--dataset", nargs="+", default=[])
|
|
|
|
parser.add_argument(
|
|
|
|
"--plugin",
|
|
|
|
type=str,
|
|
|
|
default="gemini",
|
|
|
|
choices=["gemini", "gemini_auto", "zero2", "zero2_cpu", "3d"],
|
|
|
|
help="Choose which plugin to use",
|
|
|
|
)
|
|
|
|
parser.add_argument("--load_checkpoint", type=str, default=None, help="Load checkpoint")
|
|
|
|
parser.add_argument("--save_interval", type=int, default=1000, help="Save interval")
|
|
|
|
parser.add_argument("--save_dir", type=str, default="checkpoint_dir", help="Checkpoint directory")
|
|
|
|
parser.add_argument("--tensorboard_dir", type=str, default="logs_dir", help="Tensorboard directory")
|
|
|
|
parser.add_argument("--config_file", type=str, default="config_file", help="Config file")
|
|
|
|
parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs")
|
2024-02-05 08:33:18 +00:00
|
|
|
parser.add_argument("--accumulation_steps", type=int, default=1, help="Number of accumulation steps")
|
2023-09-24 15:12:26 +00:00
|
|
|
parser.add_argument("--micro_batch_size", type=int, default=2, help="Batch size of each process")
|
|
|
|
parser.add_argument("--lr", type=float, default=3e-4, help="Learning rate")
|
2024-04-23 05:54:05 +00:00
|
|
|
parser.add_argument("--max_length", type=int, default=8192, help="Model max length")
|
2023-09-24 15:12:26 +00:00
|
|
|
parser.add_argument(
|
|
|
|
"--mixed_precision",
|
|
|
|
type=str,
|
|
|
|
default="fp16",
|
|
|
|
choices=["fp16", "bf16"],
|
|
|
|
help="Mixed precision",
|
|
|
|
)
|
|
|
|
parser.add_argument("--grad_clip", type=float, default=1.0, help="Gradient clipping value")
|
|
|
|
parser.add_argument("--weight_decay", type=float, default=0.1, help="Weight decay")
|
|
|
|
parser.add_argument("--warmup_steps", type=int, default=None, help="Warmup steps")
|
|
|
|
parser.add_argument(
|
|
|
|
"--use_grad_checkpoint",
|
|
|
|
action="store_true",
|
|
|
|
default=False,
|
|
|
|
help="Use gradient checkpointing",
|
|
|
|
)
|
|
|
|
parser.add_argument(
|
|
|
|
"--use_flash_attn",
|
|
|
|
action="store_true",
|
|
|
|
default=False,
|
|
|
|
help="Use flash-attention",
|
|
|
|
)
|
2024-02-05 08:33:18 +00:00
|
|
|
parser.add_argument(
|
|
|
|
"--use_neft",
|
|
|
|
action="store_true",
|
|
|
|
default=False,
|
|
|
|
help="Use NEFTune",
|
|
|
|
)
|
2023-09-24 15:12:26 +00:00
|
|
|
parser.add_argument(
|
|
|
|
"--freeze_non_embeds_params",
|
|
|
|
action="store_true",
|
|
|
|
default=False,
|
|
|
|
help="Freeze non embeddings parameters",
|
|
|
|
)
|
|
|
|
parser.add_argument("--tp", type=int, default=1)
|
|
|
|
parser.add_argument("--zero", type=int, default=1)
|
2024-02-05 08:33:18 +00:00
|
|
|
parser.add_argument("--pad_token", choices=["eos", "unk"], default="eos")
|
|
|
|
parser.add_argument("--padding_mode", choices=["max_length", "longest"], default="max_length")
|
[FP8] rebase main (#5963)
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use a one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.
* Fix documentation of DimSpec's difference method
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtral pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by rename the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redundant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686d1415a83d5726dc5ff02222c742582.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* fp8 operators for compressed communication
cast_to_fp8, cast_from_fp8, all_reduce_fp8
* fix scaling algorithm in FP8 casting
* support fp8 communication in pipeline parallelism
* add fp8_communication flag in the script
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* shardformer fp8
* fix rebase
* remove all to all
* fix shardformer fp8 communication training degradation
* [fp8] support all-gather flat tensor (#5932)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* Update low_level_optim.py
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
2024-08-06 08:29:37 +00:00
|
|
|
parser.add_argument(
|
|
|
|
"--skip_save_each_epoch",
|
|
|
|
action="store_true",
|
|
|
|
default=False,
|
|
|
|
help="skip saving the model checkpoint after each epoch is completed.",
|
|
|
|
)
|
2023-09-24 15:12:26 +00:00
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
with open(args.config_file, "w") as f:
|
|
|
|
json.dump(args.__dict__, f, indent=4)
|
|
|
|
|
|
|
|
# ==============================
|
|
|
|
# Initialize Distributed Training
|
|
|
|
# ==============================
|
2024-04-29 02:40:11 +00:00
|
|
|
colossalai.launch_from_torch()
|
2024-02-05 08:33:18 +00:00
|
|
|
accelerator = get_accelerator()
|
2023-09-24 15:12:26 +00:00
|
|
|
coordinator = DistCoordinator()
|
|
|
|
|
|
|
|
# ==============================
|
|
|
|
# Initialize Tensorboard
|
|
|
|
# ==============================
|
|
|
|
if coordinator.is_master():
|
|
|
|
os.makedirs(args.tensorboard_dir, exist_ok=True)
|
|
|
|
writer = SummaryWriter(args.tensorboard_dir)
|
|
|
|
|
|
|
|
# ==============================
|
|
|
|
# Initialize Booster
|
|
|
|
# ==============================
|
|
|
|
if args.plugin == "gemini":
|
|
|
|
plugin = GeminiPlugin(
|
|
|
|
precision=args.mixed_precision,
|
|
|
|
initial_scale=2**16,
|
|
|
|
max_norm=args.grad_clip,
|
2024-02-19 08:41:04 +00:00
|
|
|
enable_gradient_accumulation=(args.accumulation_steps > 1),
|
2023-09-24 15:12:26 +00:00
|
|
|
)
|
|
|
|
elif args.plugin == "gemini_auto":
|
|
|
|
plugin = GeminiPlugin(
|
|
|
|
precision=args.mixed_precision,
|
|
|
|
placement_policy="auto",
|
|
|
|
initial_scale=2**16,
|
|
|
|
max_norm=args.grad_clip,
|
2024-02-19 08:41:04 +00:00
|
|
|
enable_gradient_accumulation=(args.accumulation_steps > 1),
|
2023-09-24 15:12:26 +00:00
|
|
|
)
|
|
|
|
elif args.plugin == "zero2":
|
|
|
|
plugin = LowLevelZeroPlugin(
|
|
|
|
stage=2,
|
|
|
|
precision=args.mixed_precision,
|
|
|
|
initial_scale=2**16,
|
|
|
|
max_norm=args.grad_clip,
|
|
|
|
)
|
|
|
|
elif args.plugin == "zero2_cpu":
|
|
|
|
plugin = LowLevelZeroPlugin(
|
|
|
|
stage=2,
|
|
|
|
precision=args.mixed_precision,
|
|
|
|
initial_scale=2**16,
|
|
|
|
cpu_offload=True,
|
|
|
|
max_norm=args.grad_clip,
|
|
|
|
)
|
|
|
|
elif args.plugin == "3d":
|
|
|
|
plugin = HybridParallelPlugin(
|
|
|
|
tp_size=args.tp,
|
|
|
|
pp_size=1,
|
|
|
|
zero_stage=args.zero,
|
|
|
|
max_norm=args.grad_clip,
|
|
|
|
precision=args.mixed_precision,
|
|
|
|
)
|
|
|
|
else:
|
|
|
|
raise ValueError(f"Unknown plugin {args.plugin}")
|
|
|
|
|
|
|
|
booster = Booster(plugin=plugin)
|
|
|
|
|
|
|
|
# ======================================================
|
|
|
|
# Initialize Tokenizer, Dataset, Collator and Dataloader
|
|
|
|
# ======================================================
|
2024-04-23 05:54:05 +00:00
|
|
|
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
|
2024-02-05 08:33:18 +00:00
|
|
|
if args.pad_token == "eos":
|
|
|
|
tokenizer.pad_token = tokenizer.eos_token
|
|
|
|
elif args.pad_token == "unk":
|
|
|
|
tokenizer.pad_token = tokenizer.unk_token
|
2023-09-24 15:12:26 +00:00
|
|
|
tokenizer.add_bos_token = False
|
|
|
|
tokenizer.add_eos_token = False
|
|
|
|
|
|
|
|
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
|
|
|
|
coordinator.print_on_master(f"Tensorboard logs will be saved at: {args.tensorboard_dir}")
|
|
|
|
coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_dir}")
|
|
|
|
|
|
|
|
coordinator.print_on_master(f"Load dataset: {args.dataset}")
|
|
|
|
|
|
|
|
dataset = load_tokenized_dataset(dataset_paths=args.dataset, mode="train")
|
2024-02-05 08:33:18 +00:00
|
|
|
data_collator = DataCollatorForSupervisedDataset(
|
|
|
|
tokenizer=tokenizer, max_length=args.max_length, padding=args.padding_mode
|
|
|
|
)
|
2024-02-05 07:14:56 +00:00
|
|
|
dataloader = plugin.prepare_dataloader(
|
2023-09-24 15:12:26 +00:00
|
|
|
dataset=dataset,
|
|
|
|
batch_size=args.micro_batch_size,
|
|
|
|
shuffle=True,
|
|
|
|
drop_last=True,
|
|
|
|
collate_fn=data_collator,
|
2024-02-05 07:14:56 +00:00
|
|
|
distributed_sampler_cls=StatefulDistributedSampler,
|
2023-09-24 15:12:26 +00:00
|
|
|
)
|
|
|
|
coordinator.print_on_master(
|
2024-02-05 08:33:18 +00:00
|
|
|
f"Max device memory after data loader: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
2023-09-24 15:12:26 +00:00
|
|
|
)
|
|
|
|
|
|
|
|
# ======================================================
|
|
|
|
# Initialize Model, Objective, Optimizer and LR Scheduler
|
|
|
|
# ======================================================
|
2024-02-05 08:33:18 +00:00
|
|
|
init_ctx = (
|
|
|
|
LazyInitContext(default_device=get_current_device())
|
|
|
|
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
|
|
|
|
else nullcontext()
|
|
|
|
)
|
2023-09-24 15:12:26 +00:00
|
|
|
with init_ctx:
|
2024-02-06 11:02:37 +00:00
|
|
|
model = LlamaForCausalLM.from_pretrained(args.pretrained)
|
2023-09-24 15:12:26 +00:00
|
|
|
# Freeze part of parameters.
|
|
|
|
if args.freeze_non_embeds_params:
|
|
|
|
freeze_non_embeds_parameters(model=model)
|
2024-02-06 11:02:37 +00:00
|
|
|
# this is essential, otherwise the grad checkpoint will not work.
|
|
|
|
model.train()
|
2023-09-24 15:12:26 +00:00
|
|
|
|
|
|
|
if args.use_grad_checkpoint:
|
|
|
|
model.gradient_checkpointing_enable()
|
|
|
|
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
|
|
|
if args.use_flash_attn:
|
|
|
|
replace_with_flash_attention(model=model)
|
|
|
|
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
|
|
|
|
|
|
|
model_numel = get_model_numel(model)
|
|
|
|
coordinator.print_on_master(f"Model params: {format_numel_str(model_numel)}")
|
|
|
|
|
|
|
|
optimizer = HybridAdam(
|
2024-04-17 07:03:31 +00:00
|
|
|
model_params=(
|
|
|
|
filter(lambda p: p.requires_grad, model.parameters())
|
|
|
|
if args.freeze_non_embeds_params
|
|
|
|
else model.parameters()
|
|
|
|
),
|
2023-09-24 15:12:26 +00:00
|
|
|
lr=args.lr,
|
|
|
|
betas=(0.9, 0.95),
|
|
|
|
weight_decay=args.weight_decay,
|
|
|
|
adamw_mode=True,
|
|
|
|
)
|
|
|
|
|
2024-02-05 08:33:18 +00:00
|
|
|
if args.warmup_steps is None:
|
|
|
|
args.warmup_steps = int(args.num_epochs * 0.025 * (len(dataloader) // args.accumulation_steps))
|
|
|
|
coordinator.print_on_master(f"Warmup steps is set to {args.warmup_steps}")
|
|
|
|
|
2023-09-24 15:12:26 +00:00
|
|
|
lr_scheduler = CosineAnnealingWarmupLR(
|
|
|
|
optimizer=optimizer,
|
2024-02-05 08:33:18 +00:00
|
|
|
total_steps=args.num_epochs * (len(dataloader) // args.accumulation_steps),
|
|
|
|
warmup_steps=args.warmup_steps,
|
2023-09-24 15:12:26 +00:00
|
|
|
eta_min=0.1 * args.lr,
|
|
|
|
)
|
|
|
|
|
|
|
|
# Flash attention will be disabled because it does NOT support fp32.
|
|
|
|
default_dtype = torch.float16 if args.mixed_precision == "fp16" else torch.bfloat16
|
|
|
|
torch.set_default_dtype(default_dtype)
|
|
|
|
model, optimizer, _, dataloader, lr_scheduler = booster.boost(
|
|
|
|
model=model,
|
|
|
|
optimizer=optimizer,
|
|
|
|
lr_scheduler=lr_scheduler,
|
|
|
|
dataloader=dataloader,
|
|
|
|
)
|
|
|
|
|
|
|
|
torch.set_default_dtype(torch.float)
|
|
|
|
|
2024-02-05 08:33:18 +00:00
|
|
|
coordinator.print_on_master(
|
|
|
|
f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
|
|
|
)
|
2023-09-24 15:12:26 +00:00
|
|
|
coordinator.print_on_master(
|
|
|
|
f"Booster init max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
|
|
|
)
|
|
|
|
|
|
|
|
start_epoch = 0
|
|
|
|
start_step = 0
|
|
|
|
sampler_start_idx = 0
|
|
|
|
if args.load_checkpoint is not None:
|
|
|
|
if "modeling" in args.load_checkpoint:
|
|
|
|
coordinator.print_on_master(f"Continued pretrain from checkpoint {args.load_checkpoint}")
|
|
|
|
booster.load_model(model, args.load_checkpoint)
|
|
|
|
else:
|
|
|
|
coordinator.print_on_master(f"Load model checkpoint from {args.load_checkpoint}")
|
|
|
|
start_epoch, start_step, sampler_start_idx = load_checkpoint(
|
|
|
|
load_dir=args.load_checkpoint,
|
|
|
|
booster=booster,
|
|
|
|
model=model,
|
|
|
|
optimizer=optimizer,
|
|
|
|
lr_scheduler=lr_scheduler,
|
|
|
|
)
|
|
|
|
coordinator.print_on_master(
|
|
|
|
f"Loaded checkpoint {args.load_checkpoint} at epoch {start_epoch} step {start_step}"
|
|
|
|
)
|
|
|
|
coordinator.print_on_master(f"Loaded sample at index {sampler_start_idx}")
|
|
|
|
|
|
|
|
coordinator.print_on_master(
|
2024-02-05 08:33:18 +00:00
|
|
|
f"Checkpoint loaded max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
2023-09-24 15:12:26 +00:00
|
|
|
)
|
|
|
|
coordinator.print_on_master(
|
2024-02-05 08:33:18 +00:00
|
|
|
f"Checkpoint loaded device memory: {accelerator.memory_allocated() / 1024 ** 2:.2f} MB"
|
2023-09-24 15:12:26 +00:00
|
|
|
)
|
|
|
|
coordinator.print_on_master(
|
|
|
|
f"Checkpoint loaded max CPU memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024:.2f} MB"
|
|
|
|
)
|
|
|
|
|
2024-02-05 08:33:18 +00:00
|
|
|
if args.use_neft:
|
|
|
|
coordinator.print_on_master("Activate NEFTune.")
|
|
|
|
model, handle = activate_neftune(model)
|
|
|
|
|
|
|
|
num_steps_per_epoch = len(dataloader) // args.accumulation_steps
|
2023-09-24 15:12:26 +00:00
|
|
|
# If resume training, set the sampler start index to the correct value
|
|
|
|
assert isinstance(dataloader.sampler, StatefulDistributedSampler)
|
|
|
|
dataloader.sampler.set_start_index(start_index=sampler_start_idx)
|
|
|
|
|
|
|
|
for epoch in range(start_epoch, args.num_epochs):
|
|
|
|
dataloader.sampler.set_epoch(epoch=epoch)
|
2024-02-06 03:52:17 +00:00
|
|
|
pbar = tqdm(
|
|
|
|
desc=f"Epoch {epoch}",
|
|
|
|
disable=not coordinator.is_master(),
|
|
|
|
total=num_steps_per_epoch,
|
|
|
|
initial=start_step // args.accumulation_steps,
|
|
|
|
)
|
2024-02-05 08:33:18 +00:00
|
|
|
total_loss = torch.tensor(0.0, device=get_current_device())
|
2024-02-05 10:04:23 +00:00
|
|
|
for step, batch in enumerate(dataloader, start=start_step):
|
2024-02-05 08:33:18 +00:00
|
|
|
batch = {k: v.to(get_current_device()) for k, v in batch.items() if isinstance(v, torch.Tensor)}
|
2023-09-24 15:12:26 +00:00
|
|
|
|
2024-02-05 08:33:18 +00:00
|
|
|
batch_output = model(**batch)
|
2023-09-24 15:12:26 +00:00
|
|
|
|
2024-02-05 08:33:18 +00:00
|
|
|
loss = batch_output.loss / args.accumulation_steps
|
|
|
|
total_loss.add_(loss.data)
|
2023-09-24 15:12:26 +00:00
|
|
|
|
2024-02-05 08:33:18 +00:00
|
|
|
booster.backward(loss=loss, optimizer=optimizer)
|
2023-09-24 15:12:26 +00:00
|
|
|
|
2024-02-05 08:33:18 +00:00
|
|
|
if (step + 1) % args.accumulation_steps == 0:
|
2023-09-24 15:12:26 +00:00
|
|
|
optimizer.step()
|
|
|
|
lr_scheduler.step()
|
|
|
|
optimizer.zero_grad()
|
|
|
|
|
2024-02-05 08:33:18 +00:00
|
|
|
all_reduce_mean(tensor=total_loss)
|
|
|
|
pbar.set_postfix({"Loss": f"{total_loss.item():.4f}"})
|
2023-09-24 15:12:26 +00:00
|
|
|
if coordinator.is_master():
|
2024-02-05 08:33:18 +00:00
|
|
|
global_step = (epoch * num_steps_per_epoch) + (step + 1) // args.accumulation_steps
|
|
|
|
writer.add_scalar(tag="Loss", scalar_value=total_loss.item(), global_step=global_step)
|
2023-09-24 15:12:26 +00:00
|
|
|
writer.add_scalar(
|
|
|
|
tag="Learning Rate",
|
|
|
|
scalar_value=lr_scheduler.get_last_lr()[0],
|
|
|
|
global_step=global_step,
|
|
|
|
)
|
2024-02-05 08:33:18 +00:00
|
|
|
total_loss.fill_(0.0)
|
|
|
|
pbar.update()
|
[FP8] rebase main (#5963)
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use a one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.
* Fix documentation of DimSpec's difference method
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtra pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by rename the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redunant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686d1415a83d5726dc5ff02222c742582.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* fp8 operators for compressed communication
cast_to_fp8, cast_from_fp8, all_reduce_fp8
* fix scaling algorithm in FP8 casting
* support fp8 communication in pipeline parallelism
* add fp8_communication flag in the script
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* shardformer fp8
* fix rebase
* remove all to all
* fix shardformer fp8 communication training degradation
* [fp8] support all-gather flat tensor (#5932)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* Update low_level_optim.py
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
2024-08-06 08:29:37 +00:00
|
|
|
|
2024-02-05 08:33:18 +00:00
|
|
|
# Save modeling.
|
|
|
|
|
[FP8] rebase main (#5963)
* add SimPO
* fix dataloader
* remove debug code
* add orpo
* fix style
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix colossalai, transformers version
* fix torch colossalai version
* update transformers version
* [shardformer] DeepseekMoE support (#5871)
* [Feature] deepseek moe expert parallel implement
* [misc] fix typo, remove redundant file (#5867)
* [misc] fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] deepseek support & unit test
* [misc] remove debug code & useless print
* [misc] fix typos (#5872)
* [Feature] remove modeling file, use auto config. (#5884)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [Deepseek] remove redundant code (#5888)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [Feature/deepseek] resolve comment. (#5889)
* [misc] fix typos
* [Feature] deepseek support via auto model, remove modeling file
* [misc] delete useless file
* [misc] fix typos
* [misc] remove redundant code
* [misc] mv module replacement into if branch
* [misc] add some warning message and modify some code in unit test
* [misc] fix typos
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Hoxfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feat] Diffusion Model(PixArtAlpha/StableDiffusion3) Support (#5838)
* Diffusion Model Inference support
* Stable Diffusion 3 Support
* pixartalpha support
* [HotFix] CI,import,requirements-test for #5838 (#5892)
* [Hot Fix] CI,import,requirements-test
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [Feature] Enable PP + SP for llama (#5868)
* fix cross-PP-stage position id length diff bug
* fix typo
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* use a one cross entropy func for all shardformer models
---------
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [ShardFormer] Add Ulysses Sequence Parallelism support for Command-R, Qwen2 and ChatGLM (#5897)
* add benchmark for sft, dpo, simpo, orpo. Add benchmarking result. Support lora with gradient checkpoint
* fix style
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix eval
* hotfix citation
* [zero] support all-gather overlap (#5898)
* [zero] support all-gather overlap
* [zero] add overlap all-gather flag
* [misc] fix typo
* [zero] update api
* fix orpo cross entropy loss
* [Auto Parallel]: Speed up intra-op plan generation by 44% (#5446)
* Remove unnecessary calls to deepcopy
* Build DimSpec's difference dict only once
This change considerably speeds up construction speed of DimSpec objects. The difference_dict is the same for each DimSpec object, so a single copy of it is enough.
* Fix documentation of DimSpec's difference method
* [ShardFormer] fix qwen2 sp (#5903)
* [compatibility] support torch 2.2 (#5875)
* Support Pytorch 2.2.2
* keep build_on_pr file and update .compatibility
* fix object_to_tensor usage when torch>=2.3.0 (#5820)
* [misc] support torch2.3 (#5893)
* [misc] support torch2.3
* [devops] update compatibility ci
* [devops] update compatibility ci
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] add debug
* [devops] remove debug
* [devops] remove debug
* [release] update version (#5912)
* [plugin] support all-gather overlap for hybrid parallel (#5919)
* [plugin] fixed all-gather overlap support for hybrid parallel
* add kto
* fix style, add kto data sample
* [Examples] Add lazy init to OPT and GPT examples (#5924)
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [ColossalChat] Hotfix for ColossalChat (#5910)
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* add ignore and tiny llama
* fix path issue
* run style
* fix issue
* update bash
* fix ddp issue
* add Qwen 1.5 32B
* refactor tokenization
* [FIX BUG] UnboundLocalError: cannot access local variable 'default_conversation' where it is not associated with a value (#5931)
* cannot access local variable 'default_conversation' where it is not associated with a value
set default value for 'default_conversation'
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix test data
* refactor evaluation
* remove real data path
* remove real data path
* Add n_fused as an input from native_module (#5894)
* [FIX BUG] convert env param to int in (#5934)
* [Hotfix] Fix ZeRO typo #5936
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
* [Feature] Add a switch to control whether the model checkpoint needs to be saved after each epoch ends (#5941)
* Add a switch to control whether the model checkpoint needs to be saved after each epoch ends
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* fix style
* fix style
* fix style
* [shardformer] hotfix attn mask (#5945)
* [shardformer] hotfix attn mask (#5947)
* [Feat] Distrifusion Acceleration Support for Diffusion Inference (#5895)
* Distrifusion Support source
* comp comm overlap optimization
* sd3 benchmark
* pixart distrifusion bug fix
* sd3 bug fix and benchmark
* generation bug fix
* naming fix
* add docstring, fix counter and shape error
* add reference
* readme and requirement
* [zero] hotfix update master params (#5951)
* [release] update version (#5952)
* [Chat] Fix lora (#5946)
* fix merging
* remove filepath
* fix style
* Update README.md (#5958)
* [hotfix] Remove unused plan section (#5957)
* remove readme
* fix readme
* update
* [test] add mixtral for sequence classification
* [test] add mixtral transformer test
* [moe] fix plugin
* [test] mixtra pp shard test
* [chore] handle non member group
* [zero] solve hang
* [test] pass mixtral shardformer test
* [moe] implement transit between non moe tp and ep
* [zero] solve hang
* [misc] solve booster hang by rename the variable
* solve hang when parallel mode = pp + dp
* [moe] implement submesh initialization
* [moe] add mixtral dp grad scaling when not all experts are activated
* [chore] manually revert unintended commit
* [chore] trivial fix
* [chore] arg pass & remove drop token
* [test] add mixtral modelling test
* [moe] implement tp
* [moe] test deepseek
* [moe] clean legacy code
* [Feature] MoE Ulysses Support (#5918)
* moe sp support
* moe sp bug solve
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
---------
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [chore] minor fix
* [moe] init moe plugin comm setting with sp
* moe sp + ep bug fix
* [moe] finalize test (no pp)
* [moe] full test for deepseek and mixtral (pp + sp to fix)
* [chore] minor fix after rebase
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* [chore] solve moe ckpt test failure and some other arg pass failure
* [moe] remove ops
* [test] fix test: test_zero1_2
* [bug] fix: somehow logger hangs the program
* [moe] deepseek moe sp support
* [test] add check
* [deepseek] replace attn (a workaround for bug in transformers)
* [misc] skip redunant test
* [misc] remove debug/print code
* [moe] refactor mesh assignment
* Revert "[moe] implement submesh initialization"
This reverts commit 2f9bce6686d1415a83d5726dc5ff02222c742582.
* [chore] change moe_pg_mesh to private
* [misc] remove incompatible test config
* [misc] fix ci failure: change default value to false in moe plugin
* [misc] remove useless condition
* [chore] docstring
* [moe] remove force_overlap_comm flag and add warning instead
* [doc] add MoeHybridParallelPlugin docstring
* [moe] solve dp axis issue
* [chore] remove redundant test case, print string & reduce test tokens
* [feat] Dist Loader for Eval (#5950)
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* support auto distributed data loader
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix tp error
* remove unused parameters
* remove unused
* update inference
* update docs
* update inference
---------
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
* [lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin
* fix
* fix
* fix
* fix
* fp8 operators for compressed communication
cast_to_fp8, cast_from_fp8, all_reduce_fp8
* fix scaling algorithm in FP8 casting
* support fp8 communication in pipeline parallelism
* add fp8_communication flag in the script
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* shardformer fp8
* fix rebase
* remove all to all
* fix shardformer fp8 communication training degradation
* [fp8] support all-gather flat tensor (#5932)
* [pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
* fix
* Update low_level_optim.py
---------
Co-authored-by: YeAnbang <anbangy2@outlook.com>
Co-authored-by: Haze188 <haze188@qq.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Edenzzzz <wenxuan.tan@wisc.edu>
Co-authored-by: Edenzzzz <wtan45@wisc.edu>
Co-authored-by: Runyu Lu <77330637+LRY89757@users.noreply.github.com>
Co-authored-by: Guangyao Zhang <xjtu521@qq.com>
Co-authored-by: YeAnbang <44796419+YeAnbang@users.noreply.github.com>
Co-authored-by: Hongxin Liu <lhx0217@gmail.com>
Co-authored-by: Stephan Kö <stephankoe@users.noreply.github.com>
Co-authored-by: アマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: Tong Li <tong.li352711588@gmail.com>
Co-authored-by: zhurunhua <1281592874@qq.com>
Co-authored-by: Insu Jang <insujang@umich.edu>
Co-authored-by: Gao, Ruiyuan <905370712@qq.com>
Co-authored-by: hxwang <wang1570@e.ntu.edu.sg>
Co-authored-by: Michelle <qianranma8@gmail.com>
Co-authored-by: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Co-authored-by: HangXu <hangxu0304@gmail.com>
2024-08-06 08:29:37 +00:00
|
|
|
save_model_condition = (
|
|
|
|
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
|
|
|
|
)
|
|
|
|
|
|
|
|
if not args.skip_save_each_epoch:
|
|
|
|
save_model_condition = save_model_condition or (step + 1) == len(dataloader)
|
|
|
|
|
|
|
|
if save_model_condition:
|
2024-02-05 08:33:18 +00:00
|
|
|
coordinator.print_on_master("\nStart saving model checkpoint with running states")
|
|
|
|
|
|
|
|
if args.use_neft:
|
|
|
|
coordinator.print_on_master("Deactivate NEFTune before saving model.")
|
|
|
|
deactivate_neftune(model, handle)
|
|
|
|
|
2024-02-06 03:52:17 +00:00
|
|
|
accelerator.empty_cache()
|
2024-02-05 08:33:18 +00:00
|
|
|
save_checkpoint(
|
|
|
|
save_dir=args.save_dir,
|
|
|
|
booster=booster,
|
|
|
|
model=model,
|
|
|
|
optimizer=optimizer,
|
|
|
|
lr_scheduler=lr_scheduler,
|
|
|
|
epoch=epoch,
|
|
|
|
step=step + 1,
|
|
|
|
batch_size=args.micro_batch_size,
|
|
|
|
coordinator=coordinator,
|
|
|
|
)
|
|
|
|
coordinator.print_on_master(
|
|
|
|
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
|
|
|
|
)
|
|
|
|
|
|
|
|
if args.use_neft:
|
|
|
|
coordinator.print_on_master("Activate NEFTune.")
|
|
|
|
model, handle = activate_neftune(model)
|
|
|
|
|
|
|
|
# Delete cache.
|
|
|
|
# del batch, batch_labels, batch_output, loss
|
|
|
|
accelerator.empty_cache()
|
2023-09-24 15:12:26 +00:00
|
|
|
|
|
|
|
# the continue epochs are not resumed, so we need to reset the sampler start index and start step
|
|
|
|
dataloader.sampler.set_start_index(start_index=0)
|
|
|
|
start_step = 0
|
|
|
|
|
2024-02-05 08:33:18 +00:00
|
|
|
if args.use_neft:
|
|
|
|
coordinator.print_on_master("Deactivate NEFTune.")
|
|
|
|
deactivate_neftune(model, handle)
|
|
|
|
|
2023-09-24 15:12:26 +00:00
|
|
|
# Final save.
|
|
|
|
coordinator.print_on_master("Start saving final model checkpoint")
|
|
|
|
booster.save_model(model, os.path.join(args.save_dir, "modeling"), shard=True)
|
2024-01-09 02:20:05 +00:00
|
|
|
coordinator.print_on_master(f"Saved final model checkpoint at epoch {epoch} at folder {args.save_dir}")
|
2023-09-24 15:12:26 +00:00
|
|
|
|
2024-02-05 08:33:18 +00:00
|
|
|
coordinator.print_on_master(f"Max device memory usage: {accelerator.max_memory_allocated()/1024**2:.2f} MB")
|
2023-09-24 15:12:26 +00:00
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
main()
|