From 7e65b718154025930151be2cef29782545b539d6 Mon Sep 17 00:00:00 2001
From: YeAnbang
Date: Tue, 28 May 2024 03:14:37 +0000
Subject: [PATCH] run pre-commit

---
 .../examples/training_scripts/train_sft.py   | 42 ++++++++++---------
 1 file changed, 23 insertions(+), 19 deletions(-)

diff --git a/applications/ColossalChat/examples/training_scripts/train_sft.py b/applications/ColossalChat/examples/training_scripts/train_sft.py
index 2bdcfb833..e1af3873c 100755
--- a/applications/ColossalChat/examples/training_scripts/train_sft.py
+++ b/applications/ColossalChat/examples/training_scripts/train_sft.py
@@ -1,8 +1,10 @@
 import argparse
+import inspect
 import json
 import math
 import os
 import resource
+import sys
 from contextlib import nullcontext
 
 import torch
@@ -14,17 +16,15 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
 
 import colossalai
 from colossalai.booster import Booster
-from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, TorchDDPPlugin, LowLevelZeroPlugin
+from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
 from colossalai.cluster import DistCoordinator
+from colossalai.logging import get_dist_logger
 from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
 from colossalai.nn.optimizer import HybridAdam
-import inspect
-import sys
-import torch.distributed as dist
-from colossalai.logging import get_dist_logger
 
 logger = get_dist_logger()
 
+
 def train(args):
     print(colossalai.__version__, inspect.getfile(colossalai))
     print(sys.executable)
@@ -44,23 +44,28 @@ def train(args):
     # ==============================
     init_ctx = nullcontext()
     with init_ctx:
-        model = AutoModelForCausalLM.from_pretrained(args.pretrain,
-                                                     torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
-                                                     trust_remote_code=True)
-        # check if the hybrid parallel plugin is compatible with the model
+        model = AutoModelForCausalLM.from_pretrained(
+            args.pretrain,
+            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+            trust_remote_code=True,
+        )
+        # check if the hybrid parallel plugin is compatible with the model
         try:
             from colossalai.shardformer.policies.auto_policy import get_autopolicy
+
             policy = get_autopolicy(model)
             if policy is not None:
-                if args.plugin in ['zero2', 'zero2_cpu']:
+                if args.plugin in ["zero2", "zero2_cpu"]:
                     # if compatible, set the plugin to hybrid, which use colo-attention
-                    args.plugin = '3d'
+                    args.plugin = "3d"
                     args.zero_stage = 2
-                    if args.plugin == 'zero2_cpu':
+                    if args.plugin == "zero2_cpu":
                         args.zero_cpu_offload = True
                     else:
                         args.zero_cpu_offload = False
-                logger.info(f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}")
+                logger.info(
+                    f"Model is compatible with hybrid parallel plugin, set plugin to {args.plugin} with zero_stage {args.zero_stage} and zero_cpu_offload {args.zero_cpu_offload}"
+                )
         except NotImplementedError:
             logger.warning(f"Unable to find a policy for the model, use {args.plugin} plugin instead")
         if args.use_flash_attn:
@@ -69,10 +74,10 @@ def train(args):
                 args.pretrain,
                 torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
                 attn_implementation="flash_attention_2",
-                trust_remote_code=True
+                trust_remote_code=True,
             )
     if args.lora_rank > 0:
-        model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
+        model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
 
     if args.plugin == "ddp":
         """
@@ -87,7 +92,7 @@ def train(args):
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=True if args.accumulation_steps > 1 else False,
-            enable_flash_attention=args.use_flash_attn
+            enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -95,7 +100,7 @@ def train(args):
             placement_policy="auto",
             initial_scale=2**16,
             max_norm=args.grad_clip,
-            enable_flash_attention=args.use_flash_attn
+            enable_flash_attention=args.use_flash_attn,
        )
     elif args.plugin == "zero2":
         plugin = LowLevelZeroPlugin(
@@ -121,7 +126,7 @@ def train(args):
             zero_stage=args.zero_stage,
             enable_flash_attention=args.use_flash_attn,
             enable_sequence_parallelism=True if args.sp > 1 else False,
-            cpu_offload=True if args.zero_stage>=1 and args.zero_cpu_offload else False,
+            cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
             parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
@@ -139,7 +144,6 @@ def train(args):
     #     LazyInitContext(default_device=get_current_device()) if isinstance(plugin, (GeminiPlugin,)) else nullcontext()
     # )
 
-
     if args.grad_checkpoint and args.lora_rank == 0:
         # lora layers are not supported by gradient checkpointing
         model.gradient_checkpointing_enable()