mirror of https://github.com/hpcaitech/ColossalAI
run pre-commit
parent 929e1e3da4
commit 7e65b71815
@@ -1,8 +1,10 @@
 import argparse
+import inspect
 import json
 import math
 import os
 import resource
+import sys
 from contextlib import nullcontext

 import torch
@@ -14,17 +16,15 @@ from transformers import AutoModelForCausalLM, AutoTokenizer

 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchDDPPlugin, LowLevelZeroPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.logging import get_dist_logger
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
-import inspect
-import sys
-import torch.distributed as dist
-from colossalai.logging import get_dist_logger
+
 logger = get_dist_logger()


 def train(args):
     print(colossalai.__version__, inspect.getfile(colossalai))
     print(sys.executable)
@@ -44,23 +44,28 @@ def train(args):
     # ==============================
     init_ctx = nullcontext()
     with init_ctx:
-        model = AutoModelForCausalLM.from_pretrained(args.pretrain,
-                                                     torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
-                                                     trust_remote_code=True)
-        # check if the hybrid parallel plugin is compatible with the model
+        model = AutoModelForCausalLM.from_pretrained(
+            args.pretrain,
+            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+            trust_remote_code=True,
+        )
+        # check if the hybrid parallel plugin is compatible with the model
         try:
             from colossalai.shardformer.policies.auto_policy import get_autopolicy
+
             policy = get_autopolicy(model)
             if policy is not None:
-                if args.plugin in ['zero2', 'zero2_cpu']:
+                if args.plugin in ["zero2", "zero2_cpu"]:
                     # if compatible, set the plugin to hybrid, which use colo-attention
-                    args.plugin = '3d'
+                    args.plugin = "3d"
                     args.zero_stage = 2
-                    if args.plugin == 'zero2_cpu':
+                    if args.plugin == "zero2_cpu":
                         args.zero_cpu_offload = True
                     else:
                         args.zero_cpu_offload = False
-                    logger.info(f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}")
+                    logger.info(
+                        f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}"
+                    )
         except NotImplementedError:
             logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
         if args.use_flash_attn:
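The hunk above is mostly quote normalization and black line-wrapping, but it reformats the script's plugin-compatibility probe, which is easier to read as a unit. A minimal sketch of that logic as a standalone helper (resolve_plugin is a hypothetical name; get_autopolicy is the entry point the script itself imports, and it raises NotImplementedError when no shardformer policy exists for the architecture):

    def resolve_plugin(model, plugin, zero_stage):
        # Probe shardformer for an auto policy; supported models are upgraded
        # from plain ZeRO-2 to the hybrid "3d" plugin, enabling colo-attention.
        try:
            from colossalai.shardformer.policies.auto_policy import get_autopolicy

            if get_autopolicy(model) is not None and plugin in ("zero2", "zero2_cpu"):
                # "zero2_cpu" keeps ZeRO stage 2 but adds CPU offload.
                return "3d", 2, plugin == "zero2_cpu"
        except NotImplementedError:
            pass  # no policy for this architecture: keep the requested plugin
        return plugin, zero_stage, False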
@@ -69,10 +74,10 @@ def train(args):
                 args.pretrain,
                 torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                 attn_implementation="flash_attention_2",
-                trust_remote_code=True
+                trust_remote_code=True,
             )
         if args.lora_rank > 0:
             model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)

     if args.plugin == "ddp":
         """
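Only a trailing comma changes in this hunk, but the surrounding call is the FlashAttention load path. For reference, a self-contained sketch of the same pattern (the checkpoint name is a hypothetical placeholder; attn_implementation="flash_attention_2" requires the flash-attn package and a half-precision dtype):

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "meta-llama/Llama-2-7b-hf",  # hypothetical checkpoint
        torch_dtype=torch.bfloat16,  # flash-attn kernels need fp16/bf16
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
    )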
@@ -87,7 +92,7 @@ def train(args):
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=True if args.accumulation_steps > 1 else False,
-            enable_flash_attention=args.use_flash_attn
+            enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -95,7 +100,7 @@ def train(args):
             placement_policy="auto",
             initial_scale=2**16,
             max_norm=args.grad_clip,
-            enable_flash_attention=args.use_flash_attn
+            enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "zero2":
         plugin = LowLevelZeroPlugin(
@@ -121,7 +126,7 @@ def train(args):
             zero_stage=args.zero_stage,
             enable_flash_attention=args.use_flash_attn,
             enable_sequence_parallelism=True if args.sp > 1 else False,
-            cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False,
+            cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
             parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
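The keyword arguments in this hunk belong to the HybridParallelPlugin branch ("3d"). A sketch of how such a plugin is typically handed to a Booster afterwards; the tp_size/pp_size values are hypothetical placeholders, since this diff does not show them:

    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    plugin = HybridParallelPlugin(
        tp_size=1,  # hypothetical: tensor-parallel degree not visible in this diff
        pp_size=1,  # hypothetical: pipeline-parallel degree not visible in this diff
        zero_stage=2,
        enable_flash_attention=True,
        enable_sequence_parallelism=False,
        cpu_offload=False,
        parallel_output=False,
        max_norm=1.0,
        precision="bf16",
    )
    booster = Booster(plugin=plugin)
    # The script then wraps its objects before the training loop, e.g.:
    # model, optimizer, _, dataloader, lr_scheduler = booster.boost(
    #     model, optimizer, dataloader=dataloader, lr_scheduler=lr_scheduler
    # )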
@@ -139,7 +144,6 @@ def train(args):
     #     LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
     # )

-
     if args.grad_checkpoint and args.lora_rank == 0:
         # lora layers are not supported by gradient checkpointing
         model.gradient_checkpointing_enable()
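The guard in this final hunk enables gradient checkpointing only for full fine-tuning; the inline comment notes that the script treats LoRA layers as unsupported. As an aside, not part of this commit: recent transformers releases expose non-reentrant checkpointing, which is commonly combined with adapter methods. A hedged sketch, assuming a transformers version that accepts gradient_checkpointing_kwargs:

    # Hypothetical alternative, not used by this script: non-reentrant
    # checkpointing plus explicit input grads, so activations feeding the
    # frozen backbone still require grad.
    model.enable_input_require_grads()
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})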