run pre-commit

pull/5759/head
YeAnbang 2024-05-28 03:14:37 +00:00
parent 929e1e3da4
commit 7e65b71815
1 changed file with 23 additions and 19 deletions


@@ -1,8 +1,10 @@
 import argparse
+import inspect
 import json
 import math
 import os
 import resource
+import sys
 from contextlib import nullcontext
 
 import torch
@@ -14,17 +16,15 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchDDPPlugin, LowLevelZeroPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.logging import get_dist_logger
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
-import inspect
-import sys
-import torch.distributed as dist
-from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger()
 
+
 def train(args):
     print(colossalai.__version__, inspect.getfile(colossalai))
     print(sys.executable)
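
The import churn above is what the formatting hooks produce: isort moves `inspect` and `sys` into the alphabetized standard-library block, folds `get_dist_logger` into the sorted `colossalai` group, and clears the out-of-place imports that sat after `HybridAdam`. As a rough illustration only (not part of this commit), the same kind of rewrite can be reproduced with the formatters' Python APIs, assuming `isort` and `black` are installed:

# Sketch: reproduce the style of rewrite the isort + black pre-commit hooks apply.
import black
import isort

messy = (
    "from contextlib import nullcontext\n"
    "import sys\n"
    "import argparse\n"
    "plugin = {'name': 'zero2'}\n"
)

sorted_src = isort.code(messy)                               # groups and alphabetizes imports
formatted = black.format_str(sorted_src, mode=black.Mode())  # normalizes quotes, commas, wrapping
print(formatted)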
@@ -44,23 +44,28 @@ def train(args):
     # ==============================
     init_ctx = nullcontext()
     with init_ctx:
-        model = AutoModelForCausalLM.from_pretrained(args.pretrain,
-            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
-            trust_remote_code=True)
-        # check if the hybrid parallel plugin is compatible with the model
+        model = AutoModelForCausalLM.from_pretrained(
+            args.pretrain,
+            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+            trust_remote_code=True,
+        )
+        # check if the hybrid parallel plugin is compatible with the model
         try:
             from colossalai.shardformer.policies.auto_policy import get_autopolicy
+
             policy = get_autopolicy(model)
             if policy is not None:
-                if args.plugin in ['zero2', 'zero2_cpu']:
+                if args.plugin in ["zero2", "zero2_cpu"]:
                     # if compatible, set the plugin to hybrid, which use colo-attention
-                    args.plugin = '3d'
+                    args.plugin = "3d"
                     args.zero_stage = 2
-                    if args.plugin == 'zero2_cpu':
+                    if args.plugin == "zero2_cpu":
                         args.zero_cpu_offload = True
                     else:
                         args.zero_cpu_offload = False
-                    logger.info(f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}")
+                    logger.info(
+                        f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}"
+                    )
         except NotImplementedError:
             logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
     if args.use_flash_attn:
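
Functionally, the hunk above upgrades a `zero2`/`zero2_cpu` run to the hybrid-parallel ("3d") plugin whenever shardformer can supply an auto-policy for the model, using `zero_stage=2` and reserving CPU offload for the `zero2_cpu` case. A standalone sketch of that decision is shown below; the helper name is hypothetical, and it records the originally requested plugin before overwriting `args.plugin`, so the `zero2_cpu` branch can still match:

# Sketch of the plugin-selection logic above (hypothetical helper; argument
# names mirror the script's argparse flags).
def maybe_upgrade_to_hybrid(args, model, logger):
    try:
        from colossalai.shardformer.policies.auto_policy import get_autopolicy

        policy = get_autopolicy(model)
    except NotImplementedError:
        logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
        return args

    if policy is not None and args.plugin in ("zero2", "zero2_cpu"):
        requested = args.plugin  # remember the original choice before overwriting it
        args.plugin = "3d"       # HybridParallelPlugin path (colo-attention capable)
        args.zero_stage = 2
        args.zero_cpu_offload = requested == "zero2_cpu"
        logger.info(
            f"Using hybrid parallel plugin: zero_stage={args.zero_stage}, "
            f"zero_cpu_offload={args.zero_cpu_offload}"
        )
    return args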
@@ -69,10 +74,10 @@ def train(args):
             args.pretrain,
             torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
             attn_implementation="flash_attention_2",
-            trust_remote_code=True
+            trust_remote_code=True,
         )
     if args.lora_rank > 0:
         model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
 
     if args.plugin == "ddp":
         """
@@ -87,7 +92,7 @@ def train(args):
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=True if args.accumulation_steps > 1 else False,
-            enable_flash_attention=args.use_flash_attn
+            enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -95,7 +100,7 @@ def train(args):
             placement_policy="auto",
             initial_scale=2**16,
             max_norm=args.grad_clip,
-            enable_flash_attention=args.use_flash_attn
+            enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "zero2":
         plugin = LowLevelZeroPlugin(
@@ -121,7 +126,7 @@ def train(args):
             zero_stage=args.zero_stage,
             enable_flash_attention=args.use_flash_attn,
             enable_sequence_parallelism=True if args.sp > 1 else False,
-            cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False,
+            cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
             parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
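
All of these plugin branches converge on the same downstream wiring: the selected plugin is handed to a `Booster`, which wraps the model and optimizer before training. A minimal sketch under stated assumptions (toy model, `torchrun`-style launch environment; not taken from this file):

# Sketch: how a configured plugin is consumed by the Booster API.
import torch

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin
from colossalai.nn.optimizer import HybridAdam

colossalai.launch_from_torch()  # reads the rank/world-size env vars set by torchrun
# (older colossalai releases required colossalai.launch_from_torch(config={}))
plugin = LowLevelZeroPlugin(stage=2, precision="bf16", max_norm=1.0)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(16, 16)
optimizer = HybridAdam(model.parameters(), lr=1e-4)
# boost() returns (model, optimizer, criterion, dataloader, lr_scheduler)
model, optimizer, _, _, _ = booster.boost(model, optimizer)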
@@ -139,7 +144,6 @@ def train(args):
     #     LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
     # )
     if args.grad_checkpoint and args.lora_rank == 0:
-        # lora layers are not supported by gradient checkpointing
         model.gradient_checkpointing_enable()
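
The deleted comment carried the rationale for the `args.lora_rank == 0` guard: the script's LoRA-converted layers do not support gradient checkpointing, so checkpointing is only enabled for full fine-tuning. A hedged restatement of that guard as a standalone helper (name is hypothetical):

# Sketch: enable activation checkpointing only when the model was not
# converted to LoRA, mirroring the guard (and the removed comment) above.
def maybe_enable_grad_checkpoint(model, grad_checkpoint: bool, lora_rank: int) -> None:
    if grad_checkpoint and lora_rank == 0:
        # transformers' built-in hook; trades recompute time for activation memory
        model.gradient_checkpointing_enable()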