run pre-commit

pull/5759/head
YeAnbang 2024-05-28 03:14:37 +00:00
parent 929e1e3da4
commit 7e65b71815
1 changed file with 23 additions and 19 deletions
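
The changes below are the output of the repository's pre-commit hooks (import sorting, double-quoted strings, trailing commas, spacing around operators). As a minimal sketch, assuming a .pre-commit-config.yaml is present at the repository root, the same pass can be reproduced locally with:

    pip install pre-commit
    pre-commit install          # optional: run the hooks automatically on each commit
    pre-commit run --all-files  # apply every configured hook to the whole tree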

@@ -1,8 +1,10 @@
 import argparse
+import inspect
 import json
 import math
 import os
 import resource
+import sys
 from contextlib import nullcontext
 
 import torch
@@ -14,17 +16,15 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchDDPPlugin, LowLevelZeroPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
 from colossalai.logging import get_dist_logger
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
-import inspect
-import sys
-import torch.distributed as dist
-from colossalai.logging import get_dist_logger
 logger = get_dist_logger()
 def train(args):
     print(colossalai.__version__, inspect.getfile(colossalai))
     print(sys.executable)
@@ -44,23 +44,28 @@ def train(args):
     # ==============================
     init_ctx = nullcontext()
     with init_ctx:
-        model = AutoModelForCausalLM.from_pretrained(args.pretrain,
-            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
-            trust_remote_code=True)
-        # check if the hybrid parallel plugin is compatible with the model
+        model = AutoModelForCausalLM.from_pretrained(
+            args.pretrain,
+            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+            trust_remote_code=True,
+        )
+        # check if the hybrid parallel plugin is compatible with the model
         try:
             from colossalai.shardformer.policies.auto_policy import get_autopolicy
+
             policy = get_autopolicy(model)
             if policy is not None:
-                if args.plugin in ['zero2', 'zero2_cpu']:
+                if args.plugin in ["zero2", "zero2_cpu"]:
                     # if compatible, set the plugin to hybrid, which use colo-attention
-                    args.plugin = '3d'
+                    args.plugin = "3d"
                     args.zero_stage = 2
-                    if args.plugin == 'zero2_cpu':
+                    if args.plugin == "zero2_cpu":
                         args.zero_cpu_offload = True
                     else:
                         args.zero_cpu_offload = False
-                logger.info(f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}")
+                logger.info(
+                    f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}"
+                )
         except NotImplementedError:
             logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
         if args.use_flash_attn:
@@ -69,10 +74,10 @@ def train(args):
                 args.pretrain,
                 torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                 attn_implementation="flash_attention_2",
-                trust_remote_code=True
+                trust_remote_code=True,
             )
         if args.lora_rank > 0:
-            model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
+            model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
     if args.plugin == "ddp":
         """
@@ -87,7 +92,7 @@ def train(args):
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=True if args.accumulation_steps > 1 else False,
-            enable_flash_attention=args.use_flash_attn
+            enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -95,7 +100,7 @@ def train(args):
             placement_policy="auto",
             initial_scale=2**16,
             max_norm=args.grad_clip,
-            enable_flash_attention=args.use_flash_attn
+            enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "zero2":
         plugin = LowLevelZeroPlugin(
@@ -121,7 +126,7 @@ def train(args):
             zero_stage=args.zero_stage,
             enable_flash_attention=args.use_flash_attn,
             enable_sequence_parallelism=True if args.sp > 1 else False,
-            cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False,
+            cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
             parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
@@ -139,7 +144,6 @@ def train(args):
     #     LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
     # )
     if args.grad_checkpoint and args.lora_rank == 0:
-        # lora layers are not supported by gradient checkpointing
         model.gradient_checkpointing_enable()