From 00046a575072d3cede2eb7d51335d90db0df2512 Mon Sep 17 00:00:00 2001 From: sxl1993 <1218197792@qq.com> Date: Fri, 25 Aug 2023 06:14:28 -0700 Subject: [PATCH] add code --- ptuning/ds_train_finetune.sh | 20 +- ptuning/main_parallel.py | 436 +++++++++ ptuning/parallel_train.sh | 28 + ptuning/train.sh | 7 +- ptuning/trainer.py | 1628 ++++++++++++++++++---------------- ptuning/trainer_seq2seq.py | 54 +- requirements.txt | 9 +- 7 files changed, 1388 insertions(+), 794 deletions(-) create mode 100644 ptuning/main_parallel.py create mode 100644 ptuning/parallel_train.sh diff --git a/ptuning/ds_train_finetune.sh b/ptuning/ds_train_finetune.sh index 531a800..32667ea 100644 --- a/ptuning/ds_train_finetune.sh +++ b/ptuning/ds_train_finetune.sh @@ -1,28 +1,28 @@ +PRE_SEQ_LEN=128 +LR=2e-2 -LR=1e-4 - -MASTER_PORT=$(shuf -n 1 -i 10000-65535) - -deepspeed --num_gpus=4 --master_port $MASTER_PORT main.py \ +deepspeed --include="localhost:0,1" main.py \ --deepspeed deepspeed.json \ --do_train \ --train_file AdvertiseGen/train.json \ --test_file AdvertiseGen/dev.json \ --prompt_column content \ --response_column summary \ + --preprocessing_num_workers 10 \ --overwrite_cache \ --model_name_or_path THUDM/chatglm-6b \ - --output_dir ./output/adgen-chatglm-6b-ft-$LR \ + --output_dir ./output/ds-chatglm-6b-ptuning-$LR \ --overwrite_output_dir \ --max_source_length 64 \ --max_target_length 64 \ - --per_device_train_batch_size 4 \ + --per_device_train_batch_size 2 \ --per_device_eval_batch_size 1 \ --gradient_accumulation_steps 1 \ --predict_with_generate \ - --max_steps 5000 \ + --max_steps 1000 \ --logging_steps 10 \ --save_steps 1000 \ --learning_rate $LR \ - --fp16 - + --pre_seq_len $PRE_SEQ_LEN \ + --save_total_limit 1 \ + --fp16 \ No newline at end of file diff --git a/ptuning/main_parallel.py b/ptuning/main_parallel.py new file mode 100644 index 0000000..59c1120 --- /dev/null +++ b/ptuning/main_parallel.py @@ -0,0 +1,436 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for sequence to sequence. +""" +# You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 
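Note on the `ds_train_finetune.sh` change above: the script moves from full fine-tuning (LR=1e-4, four GPUs) to P-tuning v2 (PRE_SEQ_LEN=128, LR=2e-2, two GPUs selected via `--include="localhost:0,1"`). With `--pre_seq_len` set, only a small prefix encoder is trained while the frozen 6B backbone stays in half precision. A minimal sketch of the idea behind `--pre_seq_len` and `--prefix_projection` follows; the class name, default sizes, and shapes here are illustrative, and ChatGLM's real prefix encoder lives in its `trust_remote_code` modeling file:

```python
import torch

class PrefixEncoderSketch(torch.nn.Module):
    """Illustrative P-tuning v2 prefix encoder: pre_seq_len virtual tokens
    mapped to per-layer key/value states; only this module is trained."""

    def __init__(self, pre_seq_len=128, num_layers=4, kv_channels=64,
                 prefix_projection=False, hidden_size=256):
        super().__init__()
        out_dim = num_layers * 2 * kv_channels  # keys and values per layer
        self.prefix_projection = prefix_projection
        if prefix_projection:
            # Two-layer MLP reparameterization (the `prefix_projection` flag).
            self.embedding = torch.nn.Embedding(pre_seq_len, kv_channels)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(kv_channels, hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(hidden_size, out_dim),
            )
        else:
            self.embedding = torch.nn.Embedding(pre_seq_len, out_dim)

    def forward(self, prefix_ids: torch.LongTensor) -> torch.Tensor:
        if self.prefix_projection:
            return self.trans(self.embedding(prefix_ids))
        return self.embedding(prefix_ids)

prefix_ids = torch.arange(128).unsqueeze(0)  # (1, pre_seq_len)
past_kv = PrefixEncoderSketch()(prefix_ids)  # (1, 128, num_layers*2*kv)
```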
+ +import logging +import os +import sys +import json + +import numpy as np +from datasets import load_dataset +import jieba +from rouge_chinese import Rouge +from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction +import torch + +import transformers +from transformers import ( + AutoConfig, + AutoModel, + AutoTokenizer, + DataCollatorForSeq2Seq, + HfArgumentParser, + Seq2SeqTrainingArguments, + set_seed, +) +from trainer_seq2seq import Seq2SeqTrainer + +from arguments import ModelArguments, DataTrainingArguments + +logger = logging.getLogger(__name__) + +def main(): + + parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + + if training_args.should_log: + # The default of training_args.log_level is passive, so we set log level at info here to have that default. + transformers.utils.logging.set_verbosity_info() + + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + # datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" + + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + transformers.logging.set_verbosity_info() + # Set seed before initializing model. 
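Seeding happens before the model is created so that randomly initialized parts (here, the prefix encoder) and data shuffling are reproducible across runs. A rough sketch of what `set_seed` does; the real helper also covers TensorFlow and XLA RNGs when those backends are installed:

```python
import random

import numpy as np
import torch

def set_seed_sketch(seed: int) -> None:
    # Approximates transformers.set_seed: seed every RNG the training
    # loop touches so reruns are comparable.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)           # CPU RNG
    torch.cuda.manual_seed_all(seed)  # all CUDA device RNGs
```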
+ set_seed(training_args.seed) + + # Load dataset + data_files = {} + if data_args.train_file is not None: + data_files["train"] = data_args.train_file + extension = data_args.train_file.split(".")[-1] + if data_args.validation_file is not None: + data_files["validation"] = data_args.validation_file + extension = data_args.validation_file.split(".")[-1] + if data_args.test_file is not None: + data_files["test"] = data_args.test_file + extension = data_args.test_file.split(".")[-1] + + raw_datasets = load_dataset( + extension, + data_files=data_files, + cache_dir=model_args.cache_dir, + use_auth_token=True if model_args.use_auth_token else None, + ) + + # Load pretrained model and tokenizer + config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) + config.pre_seq_len = model_args.pre_seq_len + config.prefix_projection = model_args.prefix_projection + + tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) + + if model_args.ptuning_checkpoint is not None: + # Evaluation + # Loading extra state dict of prefix encoder + model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) + prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin")) + new_prefix_state_dict = {} + for k, v in prefix_state_dict.items(): + if k.startswith("transformer.prefix_encoder."): + new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v + model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) + else: + model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, + trust_remote_code=True, device_map="auto") + + model.gradient_checkpointing_enable() + model.enable_input_require_grads() + model.is_parallelizable = True + model.model_parallel = True + model.config.use_cache = False + + if model_args.quantization_bit is not None: + print(f"Quantized to {model_args.quantization_bit} bit") + model = model.quantize(model_args.quantization_bit) + if model_args.pre_seq_len is not None: + # P-tuning v2 + model = model.half() + model.transformer.prefix_encoder.float() + else: + # Finetune + model = model.float() + + prefix = data_args.source_prefix if data_args.source_prefix is not None else "" + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + if training_args.do_train: + column_names = raw_datasets["train"].column_names + elif training_args.do_eval: + column_names = raw_datasets["validation"].column_names + elif training_args.do_predict: + column_names = raw_datasets["test"].column_names + else: + logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") + return + + # Get the column names for input/target. + prompt_column = data_args.prompt_column + response_column = data_args.response_column + history_column = data_args.history_column + + # Temporarily set max_target_length for training. 
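The key difference from `main.py` is the non-checkpoint branch above: `device_map="auto"` lets `accelerate` shard the weights across every visible GPU, which is what allows a single process to model-parallelize ChatGLM-6B over `CUDA_VISIBLE_DEVICES=1,2` in `parallel_train.sh` below. A quick way to inspect the resulting placement, as a sketch (requires `accelerate`; the layer names printed are illustrative):

```python
import torch
from transformers import AutoModel

# Sketch: load with automatic layer-to-GPU placement and inspect it.
model = AutoModel.from_pretrained(
    "THUDM/chatglm-6b", trust_remote_code=True,
    torch_dtype=torch.float16, device_map="auto",
)
print(model.hf_device_map)
# e.g. {"transformer.word_embeddings": 0, "transformer.layers.14": 1, ...}
```

With the model placed, the script continues with the standard preprocessing path, starting from the `max_target_length` set next.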
+ max_target_length = data_args.max_target_length + + def preprocess_function_eval(examples): + inputs, targets = [], [] + for i in range(len(examples[prompt_column])): + if examples[prompt_column][i] and examples[response_column][i]: + query = examples[prompt_column][i] + if history_column is None or len(examples[history_column][i]) == 0: + prompt = query + else: + prompt = "" + history = examples[history_column][i] + for turn_idx, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + inputs.append(prompt) + targets.append(examples[response_column][i]) + + inputs = [prefix + inp for inp in inputs] + model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, truncation=True, padding=True) + labels = tokenizer(text_target=targets, max_length=max_target_length, truncation=True) + + if data_args.ignore_pad_token_for_loss: + labels["input_ids"] = [ + [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] + ] + model_inputs["labels"] = labels["input_ids"] + + return model_inputs + + def preprocess_function_train(examples): + max_seq_length = data_args.max_source_length + data_args.max_target_length + + model_inputs = { + "input_ids": [], + "labels": [], + } + for i in range(len(examples[prompt_column])): + if examples[prompt_column][i] and examples[response_column][i]: + query, answer = examples[prompt_column][i], examples[response_column][i] + + if history_column is None: + prompt = query + else: + prompt = "" + history = examples[history_column][i] + for turn_idx, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response) + prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) + + prompt = prefix + prompt + a_ids = tokenizer.encode(text=prompt, add_special_tokens=False) + b_ids = tokenizer.encode(text=answer, add_special_tokens=False) + + if len(a_ids) > data_args.max_source_length - 1: + a_ids = a_ids[: data_args.max_source_length - 1] + + if len(b_ids) > data_args.max_target_length - 2: + b_ids = b_ids[: data_args.max_target_length - 2] + + input_ids = tokenizer.build_inputs_with_special_tokens(a_ids, b_ids) + + context_length = input_ids.index(tokenizer.bos_token_id) + mask_position = context_length - 1 + labels = [-100] * context_length + input_ids[mask_position+1:] + + pad_len = max_seq_length - len(input_ids) + input_ids = input_ids + [tokenizer.pad_token_id] * pad_len + labels = labels + [tokenizer.pad_token_id] * pad_len + if data_args.ignore_pad_token_for_loss: + labels = [(l if l != tokenizer.pad_token_id else -100) for l in labels] + + model_inputs["input_ids"].append(input_ids) + model_inputs["labels"].append(labels) + + return model_inputs + + def print_dataset_example(example): + print("input_ids",example["input_ids"]) + print("inputs", tokenizer.decode(example["input_ids"])) + print("label_ids", example["labels"]) + print("labels", tokenizer.decode(example["labels"])) + + if training_args.do_train: + if "train" not in raw_datasets: + raise ValueError("--do_train requires a train dataset") + train_dataset = raw_datasets["train"] + if data_args.max_train_samples is not None: + max_train_samples = min(len(train_dataset), data_args.max_train_samples) + train_dataset = train_dataset.select(range(max_train_samples)) + with training_args.main_process_first(desc="train dataset map pre-processing"): + train_dataset = 
train_dataset.map( + preprocess_function_train, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on train dataset", + ) + print_dataset_example(train_dataset[0]) + + if training_args.do_eval: + max_target_length = data_args.val_max_target_length + if "validation" not in raw_datasets: + raise ValueError("--do_eval requires a validation dataset") + eval_dataset = raw_datasets["validation"] + if data_args.max_eval_samples is not None: + max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) + eval_dataset = eval_dataset.select(range(max_eval_samples)) + with training_args.main_process_first(desc="validation dataset map pre-processing"): + eval_dataset = eval_dataset.map( + preprocess_function_eval, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on validation dataset", + ) + print_dataset_example(eval_dataset[0]) + + if training_args.do_predict: + max_target_length = data_args.val_max_target_length + if "test" not in raw_datasets: + raise ValueError("--do_predict requires a test dataset") + predict_dataset = raw_datasets["test"] + if data_args.max_predict_samples is not None: + max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) + predict_dataset = predict_dataset.select(range(max_predict_samples)) + with training_args.main_process_first(desc="prediction dataset map pre-processing"): + predict_dataset = predict_dataset.map( + preprocess_function_eval, + batched=True, + num_proc=data_args.preprocessing_num_workers, + remove_columns=column_names, + load_from_cache_file=not data_args.overwrite_cache, + desc="Running tokenizer on prediction dataset", + ) + print_dataset_example(predict_dataset[0]) + + # Data collator + label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id + data_collator = DataCollatorForSeq2Seq( + tokenizer, + model=model, + label_pad_token_id=label_pad_token_id, + pad_to_multiple_of=None, + padding=False + ) + + # Metric + def compute_metrics(eval_preds): + preds, labels = eval_preds + if isinstance(preds, tuple): + preds = preds[0] + decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) + if data_args.ignore_pad_token_for_loss: + # Replace -100 in the labels as we can't decode them. 
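Both preprocessing paths above mask padded label positions with -100, PyTorch's default `ignore_index` for cross-entropy, so padding contributes no loss. Before decoding, `compute_metrics` must map those positions back to a real token id, which is what the `np.where` call below does. A tiny illustration (the pad id here is hypothetical):

```python
import numpy as np

labels = np.array([[5, 6, 7, -100, -100]])  # -100 marks ignored padding
pad_token_id = 3                            # hypothetical pad id
decodable = np.where(labels != -100, labels, pad_token_id)
print(decodable)  # [[5 6 7 3 3]]
```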
+ labels = np.where(labels != -100, labels, tokenizer.pad_token_id) + decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) + + score_dict = { + "rouge-1": [], + "rouge-2": [], + "rouge-l": [], + "bleu-4": [] + } + for pred, label in zip(decoded_preds, decoded_labels): + hypothesis = list(jieba.cut(pred)) + reference = list(jieba.cut(label)) + rouge = Rouge() + scores = rouge.get_scores(' '.join(hypothesis) , ' '.join(reference)) + result = scores[0] + + for k, v in result.items(): + score_dict[k].append(round(v["f"] * 100, 4)) + bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) + score_dict["bleu-4"].append(round(bleu_score * 100, 4)) + + for k, v in score_dict.items(): + score_dict[k] = float(np.mean(v)) + return score_dict + + # Override the decoding parameters of Seq2SeqTrainer +# training_args.generation_max_length = ( +# training_args.generation_max_length +# if training_args.generation_max_length is not None +# else data_args.val_max_target_length +# ) +# training_args.generation_num_beams = ( +# data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams +# ) + # Initialize our Trainer + trainer = Seq2SeqTrainer( + model=model, + args=training_args, + train_dataset=train_dataset if training_args.do_train else None, + eval_dataset=eval_dataset if training_args.do_eval else None, + tokenizer=tokenizer, + data_collator=data_collator, + compute_metrics=compute_metrics if training_args.predict_with_generate else None, + save_prefixencoder=model_args.pre_seq_len is not None + ) + + # Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + # elif last_checkpoint is not None: + # checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + # trainer.save_model() # Saves the tokenizer too for easy upload + + metrics = train_result.metrics + max_train_samples = ( + data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) + ) + metrics["train_samples"] = min(max_train_samples, len(train_dataset)) + + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # Evaluation + results = {} + max_seq_length = data_args.max_source_length + data_args.max_target_length + 1 + if training_args.do_eval: + logger.info("*** Evaluate ***") + metrics = trainer.evaluate(metric_key_prefix="eval", do_sample=True, top_p=0.7, max_length=max_seq_length, temperature=0.95) + max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) + metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + if training_args.do_predict: + logger.info("*** Predict ***") + predict_results = trainer.predict(predict_dataset, metric_key_prefix="predict", max_length=max_seq_length, do_sample=True, top_p=0.7, temperature=0.95) + metrics = predict_results.metrics + max_predict_samples = ( + data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) + ) + metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) + + trainer.log_metrics("predict", metrics) + trainer.save_metrics("predict", metrics) + + if trainer.is_world_process_zero(): + if training_args.predict_with_generate: + predictions = 
tokenizer.batch_decode( + predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + predictions = [pred.strip() for pred in predictions] + labels = tokenizer.batch_decode( + predict_results.label_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + labels = [label.strip() for label in labels] + output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") + with open(output_prediction_file, "w", encoding="utf-8") as writer: + for p, l in zip(predictions, labels): + res = json.dumps({"labels": l, "predict": p}, ensure_ascii=False) + writer.write(f"{res}\n") + return results + + +def _mp_fn(index): + # For xla_spawn (TPUs) + main() + + +if __name__ == "__main__": + main() diff --git a/ptuning/parallel_train.sh b/ptuning/parallel_train.sh new file mode 100644 index 0000000..8d4c7a2 --- /dev/null +++ b/ptuning/parallel_train.sh @@ -0,0 +1,28 @@ +PRE_SEQ_LEN=128 +LR=2e-2 + +CUDA_VISIBLE_DEVICES=1,2 python3 main_parallel.py \ + --do_train \ + --train_file AdvertiseGen/train.json \ + --test_file AdvertiseGen/dev.json \ + --prompt_column content \ + --response_column summary \ + --preprocessing_num_workers 10 \ + --overwrite_cache \ + --model_name_or_path THUDM/chatglm-6b \ + --output_dir ./output/parallel-chatglm-6b-ptuning-$LR \ + --overwrite_output_dir \ + --max_source_length 64 \ + --max_target_length 64 \ + --per_device_train_batch_size 2 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 1 \ + --predict_with_generate \ + --max_steps 1000 \ + --logging_steps 10 \ + --save_steps 1000 \ + --learning_rate $LR \ + --pre_seq_len $PRE_SEQ_LEN \ + --save_total_limit 1 \ + --gradient_checkpointing \ + --fp16 \ No newline at end of file diff --git a/ptuning/train.sh b/ptuning/train.sh index efc9a16..fcad903 100644 --- a/ptuning/train.sh +++ b/ptuning/train.sh @@ -3,10 +3,12 @@ LR=2e-2 CUDA_VISIBLE_DEVICES=0 python3 main.py \ --do_train \ + --fp16 \ --train_file AdvertiseGen/train.json \ --validation_file AdvertiseGen/dev.json \ --prompt_column content \ --response_column summary \ + --preprocessing_num_workers 10 \ --overwrite_cache \ --model_name_or_path THUDM/chatglm-6b \ --output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \ @@ -21,6 +23,5 @@ CUDA_VISIBLE_DEVICES=0 python3 main.py \ --logging_steps 10 \ --save_steps 1000 \ --learning_rate $LR \ - --pre_seq_len $PRE_SEQ_LEN \ - --quantization_bit 4 - + --pre_seq_len $PRE_SEQ_LEN + # --quantization_bit 4 diff --git a/ptuning/trainer.py b/ptuning/trainer.py index 63101bc..4f49825 100644 --- a/ptuning/trainer.py +++ b/ptuning/trainer.py @@ -17,8 +17,10 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune """ import contextlib +import copy import functools import glob +import importlib.metadata import inspect import math import os @@ -29,54 +31,43 @@ import sys import time import warnings from collections.abc import Mapping -from distutils.util import strtobool from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union -from tqdm.auto import tqdm - # Integrations must be imported before ML frameworks: # isort: off -from transformers.integrations import ( - default_hp_search_backend, +from .integrations import ( get_reporting_integration_callbacks, hp_params, is_fairscale_available, - is_optuna_available, - is_ray_tune_available, - is_sigopt_available, - is_wandb_available, - run_hp_search_optuna, - run_hp_search_ray, - run_hp_search_sigopt, - run_hp_search_wandb, ) # 
isort: on +import huggingface_hub.utils as hf_hub_utils import numpy as np import torch import torch.distributed as dist -from huggingface_hub import Repository, create_repo +from huggingface_hub import Repository, create_repo, upload_folder from packaging import version from torch import nn from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler -from torch.utils.data.distributed import DistributedSampler -from transformers import __version__ -from transformers.configuration_utils import PretrainedConfig -from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator -from transformers.debug_utils import DebugOption, DebugUnderflowOverflow -from transformers.deepspeed import deepspeed_init, is_deepspeed_zero3_enabled -from transformers.dependency_versions_check import dep_version_check -from transformers.modelcard import TrainingSummary -from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model -from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES -from transformers.optimization import Adafactor, get_scheduler -from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_10, is_torch_less_than_1_11 -from transformers.tokenization_utils_base import PreTrainedTokenizerBase -from transformers.trainer_callback import ( +from . import __version__ +from .configuration_utils import PretrainedConfig +from .data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator +from .debug_utils import DebugOption, DebugUnderflowOverflow +from .deepspeed import deepspeed_init, deepspeed_load_checkpoint +from .dependency_versions_check import dep_version_check +from .hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend +from .modelcard import TrainingSummary +from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model +from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES +from .optimization import Adafactor, get_scheduler +from .pytorch_utils import ALL_LAYERNORM_LAYERS +from .tokenization_utils_base import PreTrainedTokenizerBase +from .trainer_callback import ( CallbackHandler, DefaultFlowCallback, PrinterCallback, @@ -85,28 +76,25 @@ from transformers.trainer_callback import ( TrainerControl, TrainerState, ) -from transformers.trainer_pt_utils import ( - DistributedLengthGroupedSampler, - DistributedSamplerWithLoop, +from .trainer_pt_utils import ( DistributedTensorGatherer, IterableDatasetShard, LabelSmoother, LengthGroupedSampler, SequentialDistributedSampler, - ShardSampler, distributed_broadcast_scalars, distributed_concat, find_batch_size, + get_model_param_count, get_module_class_from_name, get_parameter_names, nested_concat, nested_detach, nested_numpify, - nested_truncate, nested_xla_mesh_reduce, reissue_pt_warnings, ) -from transformers.trainer_utils import ( +from .trainer_utils import ( PREFIX_CHECKPOINT_DIR, BestRun, EvalLoopOutput, @@ -121,7 +109,6 @@ from transformers.trainer_utils import ( TrainerMemoryTracker, TrainOutput, default_compute_objective, - default_hp_space, denumpify_detensorize, enable_full_determinism, find_executable_batch_size, @@ -132,36 +119,43 @@ from transformers.trainer_utils import ( set_seed, speed_metrics, ) -from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments -from transformers.utils import ( +from .training_args 
import OptimizerNames, ParallelMode, TrainingArguments +from .utils import ( + ADAPTER_CONFIG_NAME, + ADAPTER_SAFE_WEIGHTS_NAME, + ADAPTER_WEIGHTS_NAME, CONFIG_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME, + PushInProgress, can_return_loss, find_labels, - get_full_repo_name, is_accelerate_available, is_apex_available, + is_bitsandbytes_available, is_datasets_available, is_in_notebook, is_ipex_available, + is_peft_available, + is_safetensors_available, is_sagemaker_dp_enabled, is_sagemaker_mp_enabled, is_torch_compile_available, is_torch_neuroncore_available, is_torch_tpu_available, logging, + strtobool, ) -from transformers.utils.generic import ContextManagers +from .utils.quantization_config import QuantizationMethod -_is_native_cpu_amp_available = is_torch_greater_or_equal_than_1_10 - DEFAULT_CALLBACKS = [DefaultFlowCallback] DEFAULT_PROGRESS_CALLBACK = ProgressCallback if is_in_notebook(): - from transformers.utils.notebook import NotebookProgressCallback + from .utils.notebook import NotebookProgressCallback DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback @@ -174,7 +168,6 @@ if is_datasets_available(): if is_torch_tpu_available(check_device=False): import torch_xla.core.xla_model as xm import torch_xla.debug.metrics as met - import torch_xla.distributed.parallel_loader as pl if is_fairscale_available(): dep_version_check("fairscale") @@ -192,17 +185,31 @@ if is_sagemaker_mp_enabled(): IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10") - from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat + from .trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat else: IS_SAGEMAKER_MP_POST_1_10 = False -skip_first_batches = None -if is_accelerate_available(): - from accelerate import __version__ as accelerate_version +if is_safetensors_available(): + import safetensors.torch - if version.parse(accelerate_version) >= version.parse("0.16"): - from accelerate import skip_first_batches + +if is_peft_available(): + from peft import PeftModel + + +if is_accelerate_available(): + from accelerate import Accelerator, skip_first_batches + from accelerate import __version__ as accelerate_version + from accelerate.utils import DistributedDataParallelKwargs, GradientAccumulationPlugin + + if version.parse(accelerate_version) > version.parse("0.20.3"): + from accelerate.utils import ( + load_fsdp_model, + load_fsdp_optimizer, + save_fsdp_model, + save_fsdp_optimizer, + ) if TYPE_CHECKING: @@ -302,7 +309,8 @@ class Trainer: """ - from transformers.trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state + # Those are used as methods of the Trainer in examples. 
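Everything from here on appears to be a vendored copy of a newer upstream `transformers/trainer.py`, which is why the imports become package-relative (`from .integrations import ...`). Such imports only resolve when the file sits inside the `transformers` package; if this copy is meant to be imported as a standalone `ptuning/trainer.py`, the relative imports would presumably need to be pointed back at the installed package, for example:

```python
# Sketch: re-target the vendored copy's relative imports at the installed
# package so `import trainer` works outside the transformers source tree.
from transformers.integrations import get_reporting_integration_callbacks, hp_params
from transformers.trainer_pt_utils import _get_learning_rate, log_metrics
from transformers.utils import WEIGHTS_NAME, is_accelerate_available, logging
```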
+ from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics, save_state def __init__( self, @@ -317,9 +325,7 @@ class Trainer: callbacks: Optional[List[TrainerCallback]] = None, optimizers: Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None), preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None, - save_prefixencoder: bool = False, ): - self.save_prefixencoder = save_prefixencoder if args is None: output_dir = "tmp_trainer" logger.info(f"No `TrainingArguments` passed, using `output_dir={output_dir}`.") @@ -331,6 +337,8 @@ class Trainer: self.deepspeed = None self.is_in_train = False + self.create_accelerator_and_postprocess() + # memory metrics - must set up as early as possible self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics) self._memory_tracker.start() @@ -371,14 +379,29 @@ class Trainer: else: self.is_model_parallel = False - # At this stage the model is already loaded - if getattr(model, "is_loaded_in_8bit", False): - if getattr(model, "_is_int8_training_enabled", False): + if getattr(model, "hf_device_map", None) is not None: + devices = [device for device in set(model.hf_device_map.values()) if device not in ["cpu", "disk"]] + if len(devices) > 1: + self.is_model_parallel = True + elif len(devices) == 1: + self.is_model_parallel = self.args.device != torch.device(devices[0]) + else: + self.is_model_parallel = False + + # warn users + if self.is_model_parallel: logger.info( - "The model is loaded in 8-bit precision. To train this model you need to add additional modules" + "You have loaded a model on multiple GPUs. `is_model_parallel` attribute will be force-set" + " to `True` to avoid any unexpected behavior such as device placement mismatching." + ) + + # At this stage the model is already loaded + if getattr(model, "is_quantized", False): + if getattr(model, "_is_quantized_training_enabled", False): + logger.info( + "The model is quantized. To train this model you need to add additional modules" " inside the model such as adapters using `peft` library and freeze the model weights. Please" - " check " - " the examples in https://github.com/huggingface/peft for more details." + " check the examples in https://github.com/huggingface/peft for more details." ) else: raise ValueError( @@ -389,7 +412,7 @@ class Trainer: # Setup Sharded DDP training self.sharded_ddp = None if len(args.sharded_ddp) > 0: - if args.deepspeed: + if self.is_deepspeed_enabled: raise ValueError( "Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags." ) @@ -397,8 +420,7 @@ class Trainer: raise ValueError( "Using --sharded_ddp xxx together with --fsdp is not possible, deactivate one of those flags." ) - - if args.local_rank == -1: + if args.parallel_mode != ParallelMode.DISTRIBUTED: raise ValueError("Using sharded DDP only works in distributed training.") elif not is_fairscale_available(): raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.") @@ -416,11 +438,11 @@ class Trainer: self.fsdp = None if len(args.fsdp) > 0: - if args.deepspeed: + if self.is_deepspeed_enabled: raise ValueError( "Using --fsdp xxx together with --deepspeed is not possible, deactivate one of those flags." 
) - if not args.fsdp_config["xla"] and args.local_rank == -1: + if not args.fsdp_config["xla"] and args.parallel_mode != ParallelMode.DISTRIBUTED: raise ValueError("Using fsdp only works in distributed training.") # dep_version_check("torch>=1.12.0") @@ -440,13 +462,11 @@ class Trainer: self.fsdp = ShardingStrategy.NO_SHARD self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE - if "backward_prefetch" in self.args.fsdp_config and "backward_pos" not in self.backward_prefetch: + if "backward_prefetch" in self.args.fsdp_config and "backward_post" in self.args.fsdp_config.get( + "backward_prefetch", [] + ): self.backward_prefetch = BackwardPrefetch.BACKWARD_POST - self.forword_prefetch = False - if self.args.fsdp_config.get("forword_prefect", False): - self.forword_prefetch = True - self.limit_all_gathers = False if self.args.fsdp_config.get("limit_all_gathers", False): self.limit_all_gathers = True @@ -462,10 +482,11 @@ class Trainer: self.place_model_on_device = args.place_model_on_device if ( self.is_model_parallel - or args.deepspeed + or self.is_deepspeed_enabled or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train) or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3]) or (self.fsdp is not None) + or self.is_fsdp_enabled ): self.place_model_on_device = False @@ -475,7 +496,11 @@ class Trainer: self.eval_dataset = eval_dataset self.tokenizer = tokenizer - if self.place_model_on_device and not getattr(model, "is_loaded_in_8bit", False): + # Bnb Quantized models doesn't support `.to` operation. + if ( + self.place_model_on_device + and not getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES + ): self._move_model_to_device(model, args.device) # Force n_gpu to 1 to avoid DataParallel as MP will manage the GPUs @@ -509,7 +534,7 @@ class Trainer: " `Trainer`. Make sure the lines `import torch_xla.core.xla_model as xm` and" " `model.to(xm.xla_device())` is performed before the optimizer creation in your script." ) - if ((self.sharded_ddp is not None) or args.deepspeed or (self.fsdp is not None)) and ( + if ((self.sharded_ddp is not None) or self.is_deepspeed_enabled or (self.fsdp is not None)) and ( self.optimizer is not None or self.lr_scheduler is not None ): raise RuntimeError( @@ -526,15 +551,10 @@ class Trainer: # Will be set to True by `self._setup_loggers()` on first call to `self.log()`. self._loggers_initialized = False - # Create clone of distant repo and output directory if needed + # Create distant repo and output directory if needed + self.hub_model_id = None if self.args.push_to_hub: - self.init_git_repo(at_init=True) - # In case of pull, we need to make sure every process has the latest. - if is_torch_tpu_available(): - xm.rendezvous("init git repo") - elif args.local_rank != -1: - dist.barrier() - + self.init_hf_repo() if self.args.should_save: os.makedirs(self.args.output_dir, exist_ok=True) @@ -545,7 +565,10 @@ class Trainer: logger.info("max_steps is given, it will override any value given in num_train_epochs") if train_dataset is not None and not has_length(train_dataset) and args.max_steps <= 0: - raise ValueError("train_dataset does not implement __len__, max_steps has to be specified") + raise ValueError( + "The train_dataset does not implement __len__, max_steps has to be specified. " + "The number of steps needs to be known in advance for the learning rate scheduler." 
+ ) if ( train_dataset is not None @@ -584,47 +607,33 @@ class Trainer: "but SageMaker Model Parallelism < 1.10 does not support FP16 in trainer." ) - if args.fp16 or args.bf16: + if (args.fp16 or args.bf16) and self.sharded_ddp is not None: if args.half_precision_backend == "auto": if args.device == torch.device("cpu"): if args.fp16: raise ValueError("Tried to use `fp16` but it is not supported on cpu") - elif _is_native_cpu_amp_available: - args.half_precision_backend = "cpu_amp" else: - raise ValueError("Tried to use cpu amp but native cpu amp is not available") + args.half_precision_backend = "cpu_amp" else: args.half_precision_backend = "cuda_amp" logger.info(f"Using {args.half_precision_backend} half precision backend") self.do_grad_scaling = False - if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled() or is_torch_tpu_available()): + if (args.fp16 or args.bf16) and not (self.is_deepspeed_enabled or is_sagemaker_mp_enabled()): # deepspeed and SageMaker Model Parallel manage their own half precision - if args.half_precision_backend == "cuda_amp": - self.use_cuda_amp = True - self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16 - # bf16 does not need grad scaling - self.do_grad_scaling = self.amp_dtype == torch.float16 - if self.do_grad_scaling: - if self.sharded_ddp is not None: + if self.sharded_ddp is not None: + if args.half_precision_backend == "cuda_amp": + self.use_cuda_amp = True + self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16 + # bf16 does not need grad scaling + self.do_grad_scaling = self.amp_dtype == torch.float16 + if self.do_grad_scaling: self.scaler = ShardedGradScaler() - elif self.fsdp is not None: - from torch.distributed.fsdp.sharded_grad_scaler import ( - ShardedGradScaler as FSDPShardedGradScaler, - ) - - self.scaler = FSDPShardedGradScaler() - elif is_torch_tpu_available(): - from torch_xla.amp import GradScaler - - self.scaler = GradScaler() - else: - self.scaler = torch.cuda.amp.GradScaler() - elif args.half_precision_backend == "cpu_amp": - self.use_cpu_amp = True - self.amp_dtype = torch.bfloat16 - else: + elif args.half_precision_backend == "cpu_amp": + self.use_cpu_amp = True + self.amp_dtype = torch.bfloat16 + elif args.half_precision_backend == "apex": if not is_apex_available(): raise ImportError( "Using FP16 with APEX but APEX is not installed, please refer to" @@ -666,8 +675,9 @@ class Trainer: self.can_return_loss = can_return_loss(self.model.__class__) self.control = self.callback_handler.on_init_end(self.args, self.state, self.control) - # Internal variables to keep track of the original batch size + # Internal variables to help with automatic batch size reduction self._train_batch_size = args.train_batch_size + self._created_lr_scheduler = False # very last self._memory_tracker.stop_and_update_metrics() @@ -776,20 +786,6 @@ class Trainer: if self.train_dataset is None or not has_length(self.train_dataset): return None - generator = None - if self.args.world_size <= 1: - generator = torch.Generator() - # for backwards compatibility, we generate a seed here (which is sampled from a generator seeded with - # `args.seed`) if data_seed isn't provided. - # Further on in this method, we default to `args.seed` instead. - if self.args.data_seed is None: - seed = int(torch.empty((), dtype=torch.int64).random_().item()) - else: - seed = self.args.data_seed - generator.manual_seed(seed) - - seed = self.args.data_seed if self.args.data_seed is not None else self.args.seed - # Build the sampler. 
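The generator/`DistributedSampler` plumbing deleted above is now redundant: the DataLoader built from these samplers is later passed through `self.accelerator.prepare`, which injects per-process sharding and seeding. For `--group_by_length`, batching similar-length sequences minimizes padding; the core idea as a self-contained sketch (window size and lengths are made up):

```python
import torch

# Sketch of length grouping: shuffle, then sort within windows so each
# batch draws sequences of similar length (less wasted padding).
lengths = [12, 90, 15, 88, 11, 85, 14, 92]
window = 4  # typically batch_size times some multiplier
indices = torch.randperm(len(lengths)).tolist()
grouped = [
    sorted(indices[i:i + window], key=lambda j: lengths[j], reverse=True)
    for i in range(0, len(indices), window)
]
order = [j for w in grouped for j in w]
```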
if self.args.group_by_length: if is_datasets_available() and isinstance(self.train_dataset, datasets.Dataset): @@ -801,47 +797,15 @@ class Trainer: else: lengths = None model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None - if self.args.world_size <= 1: - return LengthGroupedSampler( - self.args.train_batch_size * self.args.gradient_accumulation_steps, - dataset=self.train_dataset, - lengths=lengths, - model_input_name=model_input_name, - generator=generator, - ) - else: - return DistributedLengthGroupedSampler( - self.args.train_batch_size * self.args.gradient_accumulation_steps, - dataset=self.train_dataset, - num_replicas=self.args.world_size, - rank=self.args.process_index, - lengths=lengths, - model_input_name=model_input_name, - seed=seed, - ) + return LengthGroupedSampler( + self.args.train_batch_size * self.args.gradient_accumulation_steps, + dataset=self.train_dataset, + lengths=lengths, + model_input_name=model_input_name, + ) else: - if self.args.world_size <= 1: - return RandomSampler(self.train_dataset, generator=generator) - elif ( - self.args.parallel_mode in [ParallelMode.TPU, ParallelMode.SAGEMAKER_MODEL_PARALLEL] - and not self.args.dataloader_drop_last - ): - # Use a loop for TPUs when drop_last is False to have all batches have the same size. - return DistributedSamplerWithLoop( - self.train_dataset, - batch_size=self.args.per_device_train_batch_size, - num_replicas=self.args.world_size, - rank=self.args.process_index, - seed=seed, - ) - else: - return DistributedSampler( - self.train_dataset, - num_replicas=self.args.world_size, - rank=self.args.process_index, - seed=seed, - ) + return RandomSampler(self.train_dataset) def get_train_dataloader(self) -> DataLoader: """ @@ -862,36 +826,19 @@ class Trainer: else: data_collator = self._get_collator_with_removed_columns(data_collator, description="training") - if isinstance(train_dataset, torch.utils.data.IterableDataset): - if self.args.world_size > 1: - train_dataset = IterableDatasetShard( - train_dataset, - batch_size=self._train_batch_size, - drop_last=self.args.dataloader_drop_last, - num_processes=self.args.world_size, - process_index=self.args.process_index, - ) + dataloader_params = { + "batch_size": self._train_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + } - return DataLoader( - train_dataset, - batch_size=self._train_batch_size, - collate_fn=data_collator, - num_workers=self.args.dataloader_num_workers, - pin_memory=self.args.dataloader_pin_memory, - ) + if not isinstance(train_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_train_sampler() + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = seed_worker - train_sampler = self._get_train_sampler() - - return DataLoader( - train_dataset, - batch_size=self._train_batch_size, - sampler=train_sampler, - collate_fn=data_collator, - drop_last=self.args.dataloader_drop_last, - num_workers=self.args.dataloader_num_workers, - pin_memory=self.args.dataloader_pin_memory, - worker_init_fn=seed_worker, - ) + return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) def _get_eval_sampler(self, eval_dataset: Dataset) -> Optional[torch.utils.data.Sampler]: # Deprecated code @@ -907,20 +854,13 @@ class Trainer: rank=smp.dp_rank(), batch_size=self.args.per_device_eval_batch_size, ) - elif self.args.local_rank != -1: - return 
SequentialDistributedSampler(eval_dataset) else: return SequentialSampler(eval_dataset) if self.args.world_size <= 1: return SequentialSampler(eval_dataset) else: - return ShardSampler( - eval_dataset, - batch_size=self.args.per_device_eval_batch_size, - num_processes=self.args.world_size, - process_index=self.args.process_index, - ) + return None def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: """ @@ -943,34 +883,18 @@ class Trainer: else: data_collator = self._get_collator_with_removed_columns(data_collator, description="evaluation") - if isinstance(eval_dataset, torch.utils.data.IterableDataset): - if self.args.world_size > 1: - eval_dataset = IterableDatasetShard( - eval_dataset, - batch_size=self.args.per_device_eval_batch_size, - drop_last=self.args.dataloader_drop_last, - num_processes=self.args.world_size, - process_index=self.args.process_index, - ) - return DataLoader( - eval_dataset, - batch_size=self.args.eval_batch_size, - collate_fn=data_collator, - num_workers=self.args.dataloader_num_workers, - pin_memory=self.args.dataloader_pin_memory, - ) + dataloader_params = { + "batch_size": self.args.eval_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + } - eval_sampler = self._get_eval_sampler(eval_dataset) + if not isinstance(eval_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_eval_sampler(eval_dataset) + dataloader_params["drop_last"] = self.args.dataloader_drop_last - return DataLoader( - eval_dataset, - sampler=eval_sampler, - batch_size=self.args.eval_batch_size, - collate_fn=data_collator, - drop_last=self.args.dataloader_drop_last, - num_workers=self.args.dataloader_num_workers, - pin_memory=self.args.dataloader_pin_memory, - ) + return self.accelerator.prepare(DataLoader(eval_dataset, **dataloader_params)) def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader: """ @@ -990,35 +914,19 @@ class Trainer: else: data_collator = self._get_collator_with_removed_columns(data_collator, description="test") - if isinstance(test_dataset, torch.utils.data.IterableDataset): - if self.args.world_size > 1: - test_dataset = IterableDatasetShard( - test_dataset, - batch_size=self.args.eval_batch_size, - drop_last=self.args.dataloader_drop_last, - num_processes=self.args.world_size, - process_index=self.args.process_index, - ) - return DataLoader( - test_dataset, - batch_size=self.args.eval_batch_size, - collate_fn=data_collator, - num_workers=self.args.dataloader_num_workers, - pin_memory=self.args.dataloader_pin_memory, - ) + dataloader_params = { + "batch_size": self.args.eval_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + } - test_sampler = self._get_eval_sampler(test_dataset) + if not isinstance(test_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_eval_sampler(test_dataset) + dataloader_params["drop_last"] = self.args.dataloader_drop_last # We use the same batch_size as for eval. 
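All three dataloader builders now collect plain `DataLoader` kwargs into `dataloader_params` and hand the result to `self.accelerator.prepare`, replacing the hand-rolled `IterableDatasetShard`/`ShardSampler` logic: in a distributed run, `prepare` wraps the loader so each process consumes only its shard of every batch. The same pattern in isolation, assuming an `accelerate launch` setup:

```python
from accelerate import Accelerator
from torch.utils.data import DataLoader

accelerator = Accelerator()
dataloader_params = {"batch_size": 2, "num_workers": 0, "pin_memory": True}
loader = accelerator.prepare(DataLoader(list(range(8)), **dataloader_params))
for batch in loader:  # under `accelerate launch`, batches are sharded per process
    print(accelerator.process_index, batch)
```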
- return DataLoader( - test_dataset, - sampler=test_sampler, - batch_size=self.args.eval_batch_size, - collate_fn=data_collator, - drop_last=self.args.dataloader_drop_last, - num_workers=self.args.dataloader_num_workers, - pin_memory=self.args.dataloader_pin_memory, - ) + return self.accelerator.prepare(DataLoader(test_dataset, **dataloader_params)) def create_optimizer_and_scheduler(self, num_training_steps: int): """ @@ -1082,10 +990,10 @@ class Trainer: for module in opt_model.modules(): if isinstance(module, nn.Embedding): skipped += sum({p.data_ptr(): p.numel() for p in module.parameters()}.values()) - print(f"skipped {module}: {skipped/2**20}M params") + logger.info(f"skipped {module}: {skipped/2**20}M params") manager.register_module_override(module, "weight", {"optim_bits": 32}) logger.debug(f"bitsandbytes: will optimize {module} in fp32") - print(f"skipped: {skipped/2**20}M params") + logger.info(f"skipped: {skipped/2**20}M params") if is_sagemaker_mp_enabled(): self.optimizer = smp.DistributedOptimizer(self.optimizer) @@ -1120,7 +1028,7 @@ class Trainer: optimizer_cls = Adafactor optimizer_kwargs.update({"scale_parameter": False, "relative_step": False}) elif args.optim == OptimizerNames.ADAMW_HF: - from transformers.optimization import AdamW + from .optimization import AdamW optimizer_cls = AdamW optimizer_kwargs.update(adam_kwargs) @@ -1147,14 +1055,45 @@ class Trainer: optimizer_kwargs.update(adam_kwargs) except ImportError: raise ValueError("Trainer tried to instantiate apex FusedAdam but apex is not installed!") - elif args.optim == OptimizerNames.ADAMW_BNB: + elif args.optim in [ + OptimizerNames.ADAMW_BNB, + OptimizerNames.ADAMW_8BIT, + OptimizerNames.PAGED_ADAMW, + OptimizerNames.PAGED_ADAMW_8BIT, + OptimizerNames.LION, + OptimizerNames.LION_8BIT, + OptimizerNames.PAGED_LION, + OptimizerNames.PAGED_LION_8BIT, + ]: try: - from bitsandbytes.optim import Adam8bit + from bitsandbytes.optim import AdamW, Lion - optimizer_cls = Adam8bit - optimizer_kwargs.update(adam_kwargs) + is_paged = False + optim_bits = 32 + optimizer_cls = None + additional_optim_kwargs = adam_kwargs + if "paged" in args.optim: + is_paged = True + if "8bit" in args.optim: + optim_bits = 8 + if "adam" in args.optim: + optimizer_cls = AdamW + elif "lion" in args.optim: + optimizer_cls = Lion + additional_optim_kwargs = {"betas": (args.adam_beta1, args.adam_beta2)} + + bnb_kwargs = {"is_paged": is_paged, "optim_bits": optim_bits} + optimizer_kwargs.update(additional_optim_kwargs) + optimizer_kwargs.update(bnb_kwargs) except ImportError: - raise ValueError("Trainer tried to instantiate bnb Adam8bit but bnb is not installed!") + raise ValueError("Trainer tried to instantiate bnb optimizer but bnb is not installed!") + if is_bitsandbytes_available() and version.parse( + importlib.metadata.version("bitsandbytes") + ) < version.parse("0.41.1"): + logger.warning( + "You are using 8-bit optimizers with a version of `bitsandbytes` < 0.41.1. " + "It is recommended to update your version as a major bug has been fixed in 8-bit optimizers." 
+ ) elif args.optim == OptimizerNames.ADAMW_ANYPRECISION: try: from torchdistx.optimizers import AnyPrecisionAdamW @@ -1198,6 +1137,7 @@ class Trainer: num_warmup_steps=self.args.get_warmup_steps(num_training_steps), num_training_steps=num_training_steps, ) + self._created_lr_scheduler = True return self.lr_scheduler def num_examples(self, dataloader: DataLoader) -> int: @@ -1230,6 +1170,8 @@ class Trainer: elif self.hp_search_backend == HPSearchBackend.WANDB: params = trial + # Unfreeze args for hyperparameter search + delattr(self.args, "_frozen") for key, value in params.items(): if not hasattr(self.args, key): logger.warning( @@ -1241,6 +1183,7 @@ class Trainer: # Casting value to the proper type if old_attr is not None: value = type(old_attr)(value) + setattr(self.args, key, value) if self.hp_search_backend == HPSearchBackend.OPTUNA: logger.info(f"Trial: {trial.params}") @@ -1248,12 +1191,21 @@ class Trainer: logger.info(f"SigOpt Assignments: {trial.assignments}") if self.hp_search_backend == HPSearchBackend.WANDB: logger.info(f"W&B Sweep parameters: {trial}") - if self.args.deepspeed: + if self.is_deepspeed_enabled: + if self.args.deepspeed is None: + raise ValueError("For sweeps with deepspeed, `args.deepspeed` must be set") # Rebuild the deepspeed config to reflect the updated training parameters + from accelerate.utils import DeepSpeedPlugin + from transformers.deepspeed import HfTrainerDeepSpeedConfig self.args.hf_deepspeed_config = HfTrainerDeepSpeedConfig(self.args.deepspeed) self.args.hf_deepspeed_config.trainer_config_process(self.args) + self.args.deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=self.args.hf_deepspeed_config) + + # Re-freeze them + setattr(self.args, "_frozen", True) + self.create_accelerator_and_postprocess() def _report_to_hp_search(self, trial: Union["optuna.Trial", Dict[str, Any]], step: int, metrics: Dict[str, float]): if self.hp_search_backend is None or trial is None: @@ -1308,9 +1260,14 @@ class Trainer: example_batch = next(iter(dataloader)) example_batch = self._prepare_inputs(example_batch) try: - jit_model = model.eval() - with ContextManagers([self.autocast_smart_context_manager(cache_enabled=False), torch.no_grad()]): - if version.parse(version.parse(torch.__version__).base_version) >= version.parse("1.14.0"): + jit_model = copy.copy(model) + jit_model.eval() + original_forward = jit_model.__dict__.pop("_original_forward", None) + # remove mixed precision hooks from the model + if original_forward: + jit_model.forward = original_forward + with self.accelerator.autocast(cache_enabled=False), torch.no_grad(): + if version.parse(version.parse(torch.__version__).base_version) >= version.parse("2.0.0"): if isinstance(example_batch, dict): jit_model = torch.jit.trace(jit_model, example_kwarg_inputs=example_batch, strict=False) else: @@ -1362,9 +1319,6 @@ class Trainer: return model def _wrap_model(self, model, training=True, dataloader=None): - if self.args.torch_compile: - model = torch.compile(model, backend=self.args.torch_compile_backend, mode=self.args.torch_compile_mode) - if self.args.use_ipex: dtype = torch.bfloat16 if self.use_cpu_amp else torch.float32 model = self.ipex_optimize_model(model, training, dtype=dtype) @@ -1375,10 +1329,6 @@ class Trainer: return self.model_wrapped return smp.DistributedModel(model, backward_passes_per_step=self.args.gradient_accumulation_steps) - # already initialized its own DDP and AMP - if self.deepspeed: - return self.deepspeed - # train/eval could be run multiple-times - if already wrapped, don't re-wrap it 
again if unwrap_model(model) is not model: return model @@ -1387,8 +1337,8 @@ class Trainer: if self.use_apex and training: model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level) - # Multi-gpu training (should be after apex fp16 initialization) - if self.args.n_gpu > 1: + # Multi-gpu training (should be after apex fp16 initialization) / 8bit models does not support DDP + if self.args.n_gpu > 1 and not getattr(model, "is_loaded_in_8bit", False): model = nn.DataParallel(model) if self.args.jit_mode_eval: @@ -1420,116 +1370,71 @@ class Trainer: cpu_offload=cpu_offload, ).to(self.args.device) # Distributed training using PyTorch FSDP - elif self.fsdp is not None: - if not self.args.fsdp_config["xla"]: - # PyTorch FSDP! - from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision - from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP - from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy - - if FSDPOption.OFFLOAD in self.args.fsdp: - cpu_offload = CPUOffload(offload_params=True) - else: - cpu_offload = CPUOffload(offload_params=False) - - auto_wrap_policy = None - - if FSDPOption.AUTO_WRAP in self.args.fsdp: - if self.args.fsdp_config["fsdp_min_num_params"] > 0: - auto_wrap_policy = functools.partial( - size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] - ) - elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: - transformer_cls_to_wrap = set() - for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: - transformer_cls = get_module_class_from_name(model, layer_class) - if transformer_cls is None: - raise Exception("Could not find the transformer layer class to wrap in the model.") - else: - transformer_cls_to_wrap.add(transformer_cls) - auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - # Transformer layer class to wrap - transformer_layer_cls=transformer_cls_to_wrap, - ) - mixed_precision_policy = None - dtype = None - if self.args.fp16: - dtype = torch.float16 - elif self.args.bf16: - dtype = torch.bfloat16 - if dtype is not None: - mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype) - if type(model) != FSDP: - # XXX: Breaking the self.model convention but I see no way around it for now. 
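The native PyTorch FSDP branch deleted below moves into `accelerate`; the Trainer keeps only the XLA variant. The `auto_wrap_policy` it builds decides which submodules become separate FSDP units, either by parameter count or by transformer-block class. The two policies in isolation, as a sketch (thresholds and the layer class are illustrative):

```python
import functools

from torch import nn
from torch.distributed.fsdp.wrap import (
    size_based_auto_wrap_policy,
    transformer_auto_wrap_policy,
)

# Size-based: any submodule with >= min_num_params becomes its own FSDP unit.
policy_by_size = functools.partial(
    size_based_auto_wrap_policy, min_num_params=1_000_000
)

# Class-based: wrap every instance of the named transformer block class.
policy_by_class = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={nn.TransformerEncoderLayer},
)
```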
- self.model = model = FSDP( - model, - sharding_strategy=self.fsdp, - cpu_offload=cpu_offload, - auto_wrap_policy=auto_wrap_policy, - mixed_precision=mixed_precision_policy, - device_id=self.args.device, - backward_prefetch=self.backward_prefetch, - forward_prefetch=self.forword_prefetch, - limit_all_gathers=self.limit_all_gathers, - ) - else: - try: - from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP - from torch_xla.distributed.fsdp import checkpoint_module - from torch_xla.distributed.fsdp.wrap import ( - size_based_auto_wrap_policy, - transformer_auto_wrap_policy, - ) - except ImportError: - raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.") - auto_wrap_policy = None - auto_wrapper_callable = None - if self.args.fsdp_config["fsdp_min_num_params"] > 0: - auto_wrap_policy = functools.partial( - size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] - ) - elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: - transformer_cls_to_wrap = set() - for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]: - transformer_cls = get_module_class_from_name(model, layer_class) - if transformer_cls is None: - raise Exception("Could not find the transformer layer class to wrap in the model.") - else: - transformer_cls_to_wrap.add(transformer_cls) - auto_wrap_policy = functools.partial( - transformer_auto_wrap_policy, - # Transformer layer class to wrap - transformer_layer_cls=transformer_cls_to_wrap, - ) - fsdp_kwargs = self.args.xla_fsdp_config - if self.args.fsdp_config["xla_fsdp_grad_ckpt"]: - # Apply gradient checkpointing to auto-wrapped sub-modules if specified - def auto_wrapper_callable(m, *args, **kwargs): - return FSDP(checkpoint_module(m), *args, **kwargs) - - # Wrap the base model with an outer FSDP wrapper - self.model = model = FSDP( - model, - auto_wrap_policy=auto_wrap_policy, - auto_wrapper_callable=auto_wrapper_callable, - **fsdp_kwargs, + elif self.fsdp is not None and self.args.fsdp_config["xla"]: + try: + from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP + from torch_xla.distributed.fsdp import checkpoint_module + from torch_xla.distributed.fsdp.wrap import ( + size_based_auto_wrap_policy, + transformer_auto_wrap_policy, ) + except ImportError: + raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.") + auto_wrap_policy = None + auto_wrapper_callable = None + default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None) + fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get( + "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap + ) - # Patch `xm.optimizer_step` should not reduce gradients in this case, - # as FSDP does not need gradient reduction over sharded parameters. 
- def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): - loss = optimizer.step(**optimizer_args) - if barrier: - xm.mark_step() - return loss + if self.args.fsdp_config["min_num_params"] > 0: + auto_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["min_num_params"] + ) + elif fsdp_transformer_layer_cls_to_wrap is not None: + transformer_cls_to_wrap = set() + for layer_class in fsdp_transformer_layer_cls_to_wrap: + transformer_cls = get_module_class_from_name(model, layer_class) + if transformer_cls is None: + raise Exception("Could not find the transformer layer class to wrap in the model.") + else: + transformer_cls_to_wrap.add(transformer_cls) - xm.optimizer_step = patched_optimizer_step + auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + # Transformer layer class to wrap + transformer_layer_cls=transformer_cls_to_wrap, + ) + fsdp_kwargs = self.args.xla_fsdp_config + if self.args.fsdp_config["xla_fsdp_grad_ckpt"]: + # Apply gradient checkpointing to auto-wrapped sub-modules if specified + def auto_wrapper_callable(m, *args, **kwargs): + return FSDP(checkpoint_module(m), *args, **kwargs) + + # Wrap the base model with an outer FSDP wrapper + self.model = model = FSDP( + model, + auto_wrap_policy=auto_wrap_policy, + auto_wrapper_callable=auto_wrapper_callable, + **fsdp_kwargs, + ) + + # Patch `xm.optimizer_step` should not reduce gradients in this case, + # as FSDP does not need gradient reduction over sharded parameters. + def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}): + loss = optimizer.step(**optimizer_args) + if barrier: + xm.mark_step() + return loss + + xm.optimizer_step = patched_optimizer_step elif is_sagemaker_dp_enabled(): model = nn.parallel.DistributedDataParallel( model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))] ) - elif self.args.local_rank != -1: + elif self.args.parallel_mode == ParallelMode.DISTRIBUTED: + if is_torch_neuroncore_available(): + return model kwargs = {} if self.args.ddp_find_unused_parameters is not None: kwargs["find_unused_parameters"] = self.args.ddp_find_unused_parameters @@ -1542,14 +1447,11 @@ class Trainer: if self.args.ddp_bucket_cap_mb is not None: kwargs["bucket_cap_mb"] = self.args.ddp_bucket_cap_mb - if is_torch_neuroncore_available(): - return model - model = nn.parallel.DistributedDataParallel( - model, - device_ids=[self.args.local_rank] if self.args._n_gpu != 0 else None, - output_device=self.args.local_rank if self.args._n_gpu != 0 else None, - **kwargs, - ) + + if self.args.ddp_broadcast_buffers is not None: + kwargs["broadcast_buffers"] = self.args.ddp_broadcast_buffers + + self.accelerator.ddp_handler = DistributedDataParallelKwargs(**kwargs) return model @@ -1573,7 +1475,7 @@ class Trainer: ignore_keys_for_eval (`List[str]`, *optional*) A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions for evaluation during the training. 
- kwargs: + kwargs (`Dict[str, Any]`, *optional*): Additional keyword arguments used to hide deprecated arguments """ if resume_from_checkpoint is False: @@ -1620,7 +1522,12 @@ class Trainer: if resume_from_checkpoint is None: raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") - if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and args.deepspeed is None: + if ( + resume_from_checkpoint is not None + and not is_sagemaker_mp_enabled() + and not self.is_deepspeed_enabled + and not self.is_fsdp_enabled + ): self._load_from_checkpoint(resume_from_checkpoint) # If model was re-initialized, put it on the right device and update self.model_wrapped @@ -1632,17 +1539,32 @@ class Trainer: inner_training_loop = find_executable_batch_size( self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size ) - return inner_training_loop( - args=args, - resume_from_checkpoint=resume_from_checkpoint, - trial=trial, - ignore_keys_for_eval=ignore_keys_for_eval, - ) + if args.push_to_hub: + try: + # Disable progress bars when uploading models during checkpoints to avoid polluting stdout + hf_hub_utils.disable_progress_bars() + return inner_training_loop( + args=args, + resume_from_checkpoint=resume_from_checkpoint, + trial=trial, + ignore_keys_for_eval=ignore_keys_for_eval, + ) + finally: + hf_hub_utils.enable_progress_bars() + else: + return inner_training_loop( + args=args, + resume_from_checkpoint=resume_from_checkpoint, + trial=trial, + ignore_keys_for_eval=ignore_keys_for_eval, + ) def _inner_training_loop( self, batch_size=None, args=None, resume_from_checkpoint=None, trial=None, ignore_keys_for_eval=None ): + self.accelerator.free_memory() self._train_batch_size = batch_size + logger.debug(f"Currently training with a batch size of: {self._train_batch_size}") # Data loader and number of training steps train_dataloader = self.get_train_dataloader() @@ -1650,7 +1572,7 @@ class Trainer: # number of training epochs: num_train_epochs # number of training steps per epoch: num_update_steps_per_epoch # total number of training steps to execute: max_steps - total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size + total_train_batch_size = self._train_batch_size * args.gradient_accumulation_steps * args.world_size len_dataloader = None if has_length(train_dataloader): @@ -1699,37 +1621,87 @@ class Trainer: and self.sharded_ddp != ShardedDDPOption.SIMPLE or is_sagemaker_mp_enabled() or self.fsdp is not None + or self.is_fsdp_enabled ) - if args.deepspeed: - deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( - self, num_training_steps=max_steps, resume_from_checkpoint=resume_from_checkpoint - ) - self.model = deepspeed_engine.module - self.model_wrapped = deepspeed_engine - self.deepspeed = deepspeed_engine - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - elif not delay_optimizer_creation: + + # We need to reset the scheduler, as its parameters may be different on subsequent calls + if self._created_lr_scheduler: + self.lr_scheduler = None + self._created_lr_scheduler = False + + if self.is_deepspeed_enabled: + self.optimizer, self.lr_scheduler = deepspeed_init(self, num_training_steps=max_steps) + + if not delay_optimizer_creation: self.create_optimizer_and_scheduler(num_training_steps=max_steps) self.state = TrainerState() self.state.is_hyper_param_search = trial is not None + # Compute absolute values for logging, eval, and save if given as ratio + if args.logging_steps is 
not None: + if args.logging_steps < 1: + self.state.logging_steps = math.ceil(max_steps * args.logging_steps) + else: + self.state.logging_steps = args.logging_steps + if args.eval_steps is not None: + if args.eval_steps < 1: + self.state.eval_steps = math.ceil(max_steps * args.eval_steps) + else: + self.state.eval_steps = args.eval_steps + if args.save_steps is not None: + if args.save_steps < 1: + self.state.save_steps = math.ceil(max_steps * args.save_steps) + else: + self.state.save_steps = args.save_steps + # Activate gradient checkpointing if needed if args.gradient_checkpointing: self.model.gradient_checkpointing_enable() model = self._wrap_model(self.model_wrapped) - if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None: + if (is_sagemaker_mp_enabled() or self.is_fsdp_enabled) and resume_from_checkpoint is not None: self._load_from_checkpoint(resume_from_checkpoint, model) + # as the model is wrapped, don't use `accelerator.prepare` + # this is for unhandled cases such as + # Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX + use_accelerator_prepare = True if model is self.model else False + + if delay_optimizer_creation: + if use_accelerator_prepare: + self.model = self.accelerator.prepare(self.model) + self.create_optimizer_and_scheduler(num_training_steps=max_steps) + + # prepare using `accelerator` prepare + if use_accelerator_prepare: + self.model.train() + if hasattr(self.lr_scheduler, "step"): + if self.use_apex: + model = self.accelerator.prepare(self.model) + else: + model, self.optimizer = self.accelerator.prepare(self.model, self.optimizer) + else: + # to handle cases wherein we pass "DummyScheduler" such as when it is specified in DeepSpeed config. + model, self.optimizer, self.lr_scheduler = self.accelerator.prepare( + self.model, self.optimizer, self.lr_scheduler + ) + + if self.is_fsdp_enabled: + self.model = model + # for the rest of this function `model` is the outside model, whether it was wrapped or not if model is not self.model: self.model_wrapped = model - if delay_optimizer_creation: - self.create_optimizer_and_scheduler(num_training_steps=max_steps) + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + + # deepspeed ckpt loading + if resume_from_checkpoint is not None and self.is_deepspeed_enabled: + deepspeed_load_checkpoint(self.model_wrapped, resume_from_checkpoint) # Check if saved optimizer or scheduler states exist self._load_optimizer_and_scheduler(resume_from_checkpoint) @@ -1740,15 +1712,15 @@ class Trainer: # Train! logger.info("***** Running training *****") - logger.info(f" Num examples = {num_examples}") - logger.info(f" Num Epochs = {num_train_epochs}") - logger.info(f" Instantaneous batch size per device = {args.per_device_train_batch_size}") - logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") + logger.info(f" Num examples = {num_examples:,}") + logger.info(f" Num Epochs = {num_train_epochs:,}") + logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") + if self.args.per_device_train_batch_size != self._train_batch_size: + logger.info(f" Training with DataParallel so batch size has been adjusted to: {self._train_batch_size:,}") + logger.info(f" Total train batch size (w. 
parallel, distributed & accumulation) = {total_train_batch_size:,}") logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") - logger.info(f" Total optimization steps = {max_steps}") - logger.info( - f" Number of trainable parameters = {sum(p.numel() for p in model.parameters() if p.requires_grad)}" - ) + logger.info(f" Total optimization steps = {max_steps:,}") + logger.info(f" Number of trainable parameters = {get_model_param_count(model, trainable_only=True):,}") self.state.epoch = 0 start_time = time.time() @@ -1772,22 +1744,10 @@ class Trainer: logger.info(f" Continuing training from epoch {epochs_trained}") logger.info(f" Continuing training from global step {self.state.global_step}") if not args.ignore_data_skip: - if skip_first_batches is None: - logger.info( - f" Will skip the first {epochs_trained} epochs then the first" - f" {steps_trained_in_current_epoch} batches in the first epoch. If this takes a lot of time," - " you can install the latest version of Accelerate with `pip install -U accelerate`.You can" - " also add the `--ignore_data_skip` flag to your launch command, but you will resume the" - " training on data already seen by your model." - ) - else: - logger.info( - f" Will skip the first {epochs_trained} epochs then the first" - f" {steps_trained_in_current_epoch} batches in the first epoch." - ) - if self.is_local_process_zero() and not args.disable_tqdm and skip_first_batches is None: - steps_trained_progress_bar = tqdm(total=steps_trained_in_current_epoch) - steps_trained_progress_bar.set_description("Skipping the first batches") + logger.info( + f" Will skip the first {epochs_trained} epochs then the first" + f" {steps_trained_in_current_epoch} batches in the first epoch." + ) # Update the references self.callback_handler.model = self.model @@ -1822,31 +1782,12 @@ class Trainer: # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point. if not args.ignore_data_skip: for epoch in range(epochs_trained): - is_random_sampler = hasattr(train_dataloader, "sampler") and isinstance( - train_dataloader.sampler, RandomSampler - ) - if is_torch_less_than_1_11 or not is_random_sampler: - # We just need to begin an iteration to create the randomization of the sampler. - # That was before PyTorch 1.11 however... - for _ in train_dataloader: - break - else: - # Otherwise we need to call the whooooole sampler cause there is some random operation added - # AT THE VERY END! - _ = list(train_dataloader.sampler) + for _ in train_dataloader: + break total_batched_samples = 0 for epoch in range(epochs_trained, num_train_epochs): - if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler): - train_dataloader.sampler.set_epoch(epoch) - elif hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDatasetShard): - train_dataloader.dataset.set_epoch(epoch) - - if is_torch_tpu_available(): - parallel_loader = pl.ParallelLoader(train_dataloader, [args.device]).per_device_loader(args.device) - epoch_iterator = parallel_loader - else: - epoch_iterator = train_dataloader + epoch_iterator = train_dataloader # Reset the past mems state at the beginning of each epoch if necessary. 
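The fractional `logging_steps`/`eval_steps`/`save_steps` handling introduced a few lines above resolves ratios against `max_steps`. A self-contained sketch of that rule (illustrative, not part of the patch; `resolve_steps` is a hypothetical helper):

import math


def resolve_steps(configured: float, max_steps: int) -> int:
    # Values in (0, 1) are read as a ratio of max_steps; values >= 1 are absolute.
    if configured < 1:
        return math.ceil(max_steps * configured)
    return int(configured)


assert resolve_steps(0.05, 2000) == 100  # log/save every 5% of training
assert resolve_steps(500, 2000) == 500   # plain step count, unchanged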
if args.past_index >= 0: @@ -1864,7 +1805,7 @@ class Trainer: rng_to_sync = False steps_skipped = 0 - if skip_first_batches is not None and steps_trained_in_current_epoch > 0: + if steps_trained_in_current_epoch > 0: epoch_iterator = skip_first_batches(epoch_iterator, steps_trained_in_current_epoch) steps_skipped = steps_trained_in_current_epoch steps_trained_in_current_epoch = 0 @@ -1892,15 +1833,7 @@ class Trainer: if step % args.gradient_accumulation_steps == 0: self.control = self.callback_handler.on_step_begin(args, self.state, self.control) - if ( - (total_batched_samples % args.gradient_accumulation_steps != 0) - and args.local_rank != -1 - and args._no_sync_in_gradient_accumulation - ): - # Avoid unnecessary DDP synchronization since there will be no backward pass on this example. - with model.no_sync(): - tr_loss_step = self.training_step(model, inputs) - else: + with self.accelerator.accumulate(model): tr_loss_step = self.training_step(model, inputs) if ( @@ -1915,17 +1848,25 @@ class Trainer: self.current_flos += float(self.floating_point_ops(inputs)) - # Optimizer step for deepspeed must be called on every step regardless of the value of gradient_accumulation_steps - if self.deepspeed: - self.deepspeed.step() + is_last_step_and_steps_less_than_grad_acc = ( + steps_in_epoch <= args.gradient_accumulation_steps and (step + 1) == steps_in_epoch + ) - if total_batched_samples % args.gradient_accumulation_steps == 0 or ( + if ( + total_batched_samples % args.gradient_accumulation_steps == 0 + or # last step in epoch but step is always smaller than gradient_accumulation_steps - steps_in_epoch <= args.gradient_accumulation_steps - and (step + 1) == steps_in_epoch + is_last_step_and_steps_less_than_grad_acc ): + # the `or` condition of `is_last_step_and_steps_less_than_grad_acc` is not covered + # in accelerate. So, explicitly enable sync gradients to True in that case. 
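The `with self.accelerator.accumulate(model)` context introduced above replaces the manual `no_sync` bookkeeping: gradient synchronization and the real optimizer step only happen on accumulation boundaries. A self-contained sketch of the same pattern with a toy model (illustrative, not part of the patch):

import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model, optimizer = accelerator.prepare(model, optimizer)

for step in range(8):
    x, y = torch.randn(2, 8), torch.randn(2, 1)
    with accelerator.accumulate(model):
        # On non-boundary micro-steps accelerate skips gradient sync and turns
        # optimizer.step() into a no-op, so the loop body stays uniform.
        loss = torch.nn.functional.mse_loss(model(x), y)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()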
+ if is_last_step_and_steps_less_than_grad_acc or ( + version.parse(accelerate_version) <= version.parse("0.20.3") + ): + self.accelerator.gradient_state._set_sync_gradients(True) + # Gradient clipping - if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed: + if args.max_grad_norm is not None and args.max_grad_norm > 0: # deepspeed does its own clipping if self.do_grad_scaling: @@ -1944,23 +1885,27 @@ class Trainer: elif hasattr(model, "clip_grad_norm_"): # Some models (like FullyShardedDDP) have a specific way to do gradient clipping model.clip_grad_norm_(args.max_grad_norm) - else: + elif self.use_apex: # Revert to normal clipping otherwise, handling Apex or full precision nn.utils.clip_grad_norm_( - amp.master_params(self.optimizer) if self.use_apex else model.parameters(), + amp.master_params(self.optimizer), + args.max_grad_norm, + ) + else: + self.accelerator.clip_grad_norm_( + model.parameters(), args.max_grad_norm, ) # Optimizer step optimizer_was_run = True - if self.deepspeed: - pass # called outside the loop - elif is_torch_tpu_available(): + if is_torch_tpu_available(): if self.do_grad_scaling: self.scaler.step(self.optimizer) self.scaler.update() else: - xm.optimizer_step(self.optimizer) + # tpu-comment: accelerate wrapped optimizers call xm.optimizer_step + self.optimizer.step() elif self.do_grad_scaling: scale_before = self.scaler.get_scale() self.scaler.step(self.optimizer) @@ -1969,9 +1914,12 @@ class Trainer: optimizer_was_run = scale_before <= scale_after else: self.optimizer.step() + optimizer_was_run = not self.accelerator.optimizer_step_was_skipped - if optimizer_was_run and not self.deepspeed: - self.lr_scheduler.step() + if optimizer_was_run: + # Delay optimizer scheduling until metrics are generated + if not isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + self.lr_scheduler.step() model.zero_grad() self.state.global_step += 1 @@ -2016,7 +1964,7 @@ class Trainer: # Wait for everyone to get here so we are sur the model has been saved by process 0. if is_torch_tpu_available(): xm.rendezvous("load_best_model_at_end") - elif args.local_rank != -1: + elif args.parallel_mode == ParallelMode.DISTRIBUTED: dist.barrier() elif is_sagemaker_mp_enabled(): smp.barrier() @@ -2044,12 +1992,15 @@ class Trainer: # Delete the last checkpoint when save_total_limit=1 if it's different from the best checkpoint and process allowed to save. if self.args.should_save and self.state.best_model_checkpoint is not None and self.args.save_total_limit == 1: for checkpoint in checkpoints_sorted: - if checkpoint != self.state.best_model_checkpoint: + if not os.path.samefile(checkpoint, self.state.best_model_checkpoint): logger.info(f"Deleting older checkpoint [{checkpoint}] due to args.save_total_limit") shutil.rmtree(checkpoint) self.control = self.callback_handler.on_train_end(args, self.state, self.control) + # Wait for the checkpoint to be uploaded. 
+        self._finish_current_push()
+
         return TrainOutput(self.state.global_step, train_loss, metrics)
 
     def _get_output_dir(self, trial):
@@ -2076,15 +2027,31 @@ class Trainer:
         if model is None:
             model = self.model
 
-        if not os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)) and not os.path.isfile(
-            os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
+        config_file = os.path.join(resume_from_checkpoint, CONFIG_NAME)
+        adapter_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_WEIGHTS_NAME)
+        adapter_safe_weights_file = os.path.join(resume_from_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME)
+        weights_file = os.path.join(resume_from_checkpoint, WEIGHTS_NAME)
+        weights_index_file = os.path.join(resume_from_checkpoint, WEIGHTS_INDEX_NAME)
+        safe_weights_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_NAME)
+        safe_weights_index_file = os.path.join(resume_from_checkpoint, SAFE_WEIGHTS_INDEX_NAME)
+
+        if not any(
+            os.path.isfile(f)
+            for f in [
+                weights_file,
+                safe_weights_file,
+                weights_index_file,
+                safe_weights_index_file,
+                adapter_weights_file,
+                adapter_safe_weights_file,
+            ]
         ):
             raise ValueError(f"Can't find a valid checkpoint at {resume_from_checkpoint}")
 
         logger.info(f"Loading model from {resume_from_checkpoint}.")
 
-        if os.path.isfile(os.path.join(resume_from_checkpoint, CONFIG_NAME)):
-            config = PretrainedConfig.from_json_file(os.path.join(resume_from_checkpoint, CONFIG_NAME))
+        if os.path.isfile(config_file):
+            config = PretrainedConfig.from_json_file(config_file)
             checkpoint_version = config.transformers_version
             if checkpoint_version is not None and checkpoint_version != __version__:
                 logger.warning(
@@ -2093,7 +2060,7 @@ class Trainer:
                     "yield to errors or unwanted behaviors."
                 )
 
-        if os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
+        if os.path.isfile(weights_file) or os.path.isfile(safe_weights_file):
             # If the model is on the GPU, it still works!
             if is_sagemaker_mp_enabled():
                 if os.path.isfile(os.path.join(resume_from_checkpoint, "user_content.pt")):
@@ -2109,74 +2076,123 @@ class Trainer:
                         logger.warning(
                             "Enabling FP16 and loading from smp < 1.10 checkpoint together is not suppported."
                         )
-                    state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
+                    state_dict = torch.load(weights_file, map_location="cpu")
                     # Required for smp to not auto-translate state_dict from hf to smp (is already smp).
                     state_dict["_smp_is_partial"] = False
                     load_result = model.load_state_dict(state_dict, strict=True)
                     # release memory
                     del state_dict
+            elif self.is_fsdp_enabled:
+                load_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, model, resume_from_checkpoint)
             else:
                 # We load the model state dict on the CPU to avoid an OOM error.
-                state_dict = torch.load(os.path.join(resume_from_checkpoint, WEIGHTS_NAME), map_location="cpu")
+                if self.args.save_safetensors and os.path.isfile(safe_weights_file):
+                    state_dict = safetensors.torch.load_file(safe_weights_file, device="cpu")
+                else:
+                    state_dict = torch.load(weights_file, map_location="cpu")
+
                 # workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963
                 # which takes *args instead of **kwargs
                 load_result = model.load_state_dict(state_dict, False)
                 # release memory
                 del state_dict
                 self._issue_warnings_after_load(load_result)
+
+        # Load adapters following PR # 24096
+        elif is_peft_available() and isinstance(model, PeftModel):
+            # If training a model using PEFT & LoRA, assume that the adapter has been saved properly.
+ if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"): + if os.path.exists(resume_from_checkpoint): + model.load_adapter(resume_from_checkpoint, model.active_adapter, is_trainable=True) + else: + logger.warning( + "The intermediate checkpoints of PEFT may not be saved correctly, " + f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. " + "Check some examples here: https://github.com/huggingface/peft/issues/96" + ) + else: + logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed") else: # We load the sharded checkpoint - load_result = load_sharded_checkpoint(model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled()) + load_result = load_sharded_checkpoint( + model, resume_from_checkpoint, strict=is_sagemaker_mp_enabled(), prefer_safe=self.args.save_safetensors + ) if not is_sagemaker_mp_enabled(): self._issue_warnings_after_load(load_result) def _load_best_model(self): logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") best_model_path = os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME) - model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model - if os.path.exists(best_model_path): - if self.deepspeed: - if self.model_wrapped is not None: - # this removes the pre-hooks from the previous engine - self.model_wrapped.destroy() - self.model_wrapped = None + best_safe_model_path = os.path.join(self.state.best_model_checkpoint, SAFE_WEIGHTS_NAME) + best_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_WEIGHTS_NAME) + best_safe_adapter_model_path = os.path.join(self.state.best_model_checkpoint, ADAPTER_SAFE_WEIGHTS_NAME) - # temp hack until Deepspeed fixes the problem with resume from an existing engine that did some stepping - deepspeed_engine, optimizer, lr_scheduler = deepspeed_init( - self, - num_training_steps=self.args.max_steps, - resume_from_checkpoint=self.state.best_model_checkpoint, - ) - self.model = deepspeed_engine.module - self.model_wrapped = deepspeed_engine - self.deepspeed = deepspeed_engine - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - else: - if is_sagemaker_mp_enabled(): - if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")): - # If the 'user_content.pt' file exists, load with the new smp api. - # Checkpoint must have been saved with the new smp api. - smp.resume_from_checkpoint( - path=self.state.best_model_checkpoint, - tag=WEIGHTS_NAME, - partial=False, - load_optimizer=False, - ) + model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + if self.is_deepspeed_enabled: + deepspeed_load_checkpoint(self.model_wrapped, self.state.best_model_checkpoint) + elif ( + os.path.exists(best_model_path) + or os.path.exists(best_safe_model_path) + or os.path.exists(best_adapter_model_path) + or os.path.exists(best_safe_adapter_model_path) + ): + has_been_loaded = True + if is_sagemaker_mp_enabled(): + if os.path.isfile(os.path.join(self.state.best_model_checkpoint, "user_content.pt")): + # If the 'user_content.pt' file exists, load with the new smp api. + # Checkpoint must have been saved with the new smp api. + smp.resume_from_checkpoint( + path=self.state.best_model_checkpoint, + tag=WEIGHTS_NAME, + partial=False, + load_optimizer=False, + ) + else: + # If the 'user_content.pt' file does NOT exist, load with the old smp api. + # Checkpoint must have been saved with the old smp api. 
+                    if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
+                        state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
                     else:
-                        # If the 'user_content.pt' file does NOT exist, load with the old smp api.
-                        # Checkpoint must have been saved with the old smp api.
                         state_dict = torch.load(best_model_path, map_location="cpu")
-                    state_dict["_smp_is_partial"] = False
-                    load_result = model.load_state_dict(state_dict, strict=True)
+
+                    state_dict["_smp_is_partial"] = False
+                    load_result = model.load_state_dict(state_dict, strict=True)
+            elif self.is_fsdp_enabled:
+                load_result = load_fsdp_model(
+                    self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint
+                )
+            else:
+                if is_peft_available() and isinstance(model, PeftModel):
+                    # If training a model using PEFT & LoRA, assume that the adapter has been saved properly.
+                    if hasattr(model, "active_adapter") and hasattr(model, "load_adapter"):
+                        if os.path.exists(best_adapter_model_path) or os.path.exists(best_safe_adapter_model_path):
+                            model.load_adapter(self.state.best_model_checkpoint, model.active_adapter)
+                            # `load_adapter` has no return value, so construct an empty load_result for the checks below.
+                            from torch.nn.modules.module import _IncompatibleKeys
+
+                            load_result = _IncompatibleKeys([], [])
+                        else:
+                            logger.warning(
+                                "The intermediate checkpoints of PEFT may not be saved correctly, "
+                                f"consider using a custom callback to save {ADAPTER_WEIGHTS_NAME} in corresponding saving folders. "
+                                "Check some examples here: https://github.com/huggingface/peft/issues/96"
+                            )
+                            has_been_loaded = False
+                    else:
+                        logger.warning("Could not load adapter model, make sure to have `peft>=0.3.0` installed")
+                        has_been_loaded = False
                 else:
                     # We load the model state dict on the CPU to avoid an OOM error.
-                    state_dict = torch.load(best_model_path, map_location="cpu")
+                    if self.args.save_safetensors and os.path.isfile(best_safe_model_path):
+                        state_dict = safetensors.torch.load_file(best_safe_model_path, device="cpu")
+                    else:
+                        state_dict = torch.load(best_model_path, map_location="cpu")
+
+                    # If the model is on the GPU, it still works!
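For context, the `safetensors` round trip used by the `save_safetensors` branches above looks like this in isolation (a sketch, not part of the patch; the filename is a placeholder):

import torch
from safetensors.torch import load_file, save_file

tensors = {"weight": torch.randn(2, 2), "bias": torch.zeros(2)}
save_file(tensors, "model.safetensors")

# Equivalent to the safetensors.torch.load_file(...) calls in the trainer:
# tensors are read without running pickle, unlike torch.load.
state_dict = load_file("model.safetensors", device="cpu")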
# workaround for FSDP bug https://github.com/pytorch/pytorch/issues/82963 # which takes *args instead of **kwargs load_result = model.load_state_dict(state_dict, False) - if not is_sagemaker_mp_enabled(): + if not is_sagemaker_mp_enabled() and has_been_loaded: self._issue_warnings_after_load(load_result) elif os.path.exists(os.path.join(self.state.best_model_checkpoint, WEIGHTS_INDEX_NAME)): load_result = load_sharded_checkpoint( @@ -2228,16 +2244,25 @@ class Trainer: metrics = None if self.control.should_evaluate: if isinstance(self.eval_dataset, dict): + metrics = {} for eval_dataset_name, eval_dataset in self.eval_dataset.items(): - metrics = self.evaluate( + dataset_metrics = self.evaluate( eval_dataset=eval_dataset, ignore_keys=ignore_keys_for_eval, metric_key_prefix=f"eval_{eval_dataset_name}", ) + metrics.update(dataset_metrics) else: metrics = self.evaluate(ignore_keys=ignore_keys_for_eval) self._report_to_hp_search(trial, self.state.global_step, metrics) + # Run delayed LR scheduler now that metrics are populated + if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + metric_to_check = self.args.metric_for_best_model + if not metric_to_check.startswith("eval_"): + metric_to_check = f"eval_{metric_to_check}" + self.lr_scheduler.step(metrics[metric_to_check]) + if self.control.should_save: self._save_checkpoint(model, trial, metrics=metrics) self.control = self.callback_handler.on_save(self.args, self.state, self.control) @@ -2270,11 +2295,11 @@ class Trainer: np.random.set_state(checkpoint_rng_state["numpy"]) torch.random.set_rng_state(checkpoint_rng_state["cpu"]) if torch.cuda.is_available(): - if self.args.local_rank != -1: - torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"]) + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: + torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"]) else: try: - torch.cuda.random.set_rng_state_all(checkpoint_rng_state["cuda"]) + torch.cuda.random.set_rng_state(checkpoint_rng_state["cuda"]) except Exception as e: logger.info( f"Didn't manage to set back the RNG states of the GPU because of the following error:\n {e}" @@ -2297,15 +2322,27 @@ class Trainer: run_dir = self._get_output_dir(trial=trial) output_dir = os.path.join(run_dir, checkpoint_folder) self.save_model(output_dir, _internal_call=True) - if self.deepspeed: + if self.is_deepspeed_enabled: # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed # config `stage3_gather_16bit_weights_on_model_save` is True - self.deepspeed.save_checkpoint(output_dir) + self.model_wrapped.save_checkpoint(output_dir) # Save optimizer and scheduler if self.sharded_ddp == ShardedDDPOption.SIMPLE: self.optimizer.consolidate_state_dict() + if self.fsdp or self.is_fsdp_enabled: + if self.is_fsdp_enabled: + save_fsdp_optimizer( + self.accelerator.state.fsdp_plugin, self.accelerator, self.optimizer, self.model, output_dir + ) + else: + # FSDP has a different interface for saving optimizer states. + # Needs to be called on all ranks to gather all states. + # full_optim_state_dict will be deprecated after Pytorch 2.2! 
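The rank-0 save path used here has a matching load path further down (`scatter_full_optim_state_dict`). A sketch of the round trip under the stated assumption of an already-initialized FSDP process group (illustrative, not part of the patch; kept as comments because it only runs inside `torchrun`):

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# Saving: gather the full (unsharded) optimizer state dict; only rank 0 writes it.
# full_osd = FSDP.full_optim_state_dict(model, optimizer)
# if rank == 0:
#     torch.save(full_osd, "optimizer.pt")

# Loading: rank 0 reads the full state dict, then every rank receives its shard.
# full_osd = torch.load("optimizer.pt") if rank == 0 else None
# sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model)
# optimizer.load_state_dict(sharded_osd)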
+ full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer) + torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME)) + if is_torch_tpu_available(): xm.rendezvous("saving_optimizer_states") xm.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) @@ -2328,9 +2365,10 @@ class Trainer: reissue_pt_warnings(caught_warnings) if self.do_grad_scaling: torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) - elif self.args.should_save and not self.deepspeed: + elif self.args.should_save and not self.is_deepspeed_enabled and not (self.fsdp or self.is_fsdp_enabled): # deepspeed.save_checkpoint above saves model/optim/sched torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME)) + with warnings.catch_warnings(record=True) as caught_warnings: torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) reissue_pt_warnings(caught_warnings) @@ -2364,7 +2402,7 @@ class Trainer: "cpu": torch.random.get_rng_state(), } if torch.cuda.is_available(): - if self.args.local_rank == -1: + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: # In non distributed, we save the global CUDA RNG state (will take care of DataParallel) rng_states["cuda"] = torch.cuda.random.get_rng_state_all() else: @@ -2394,7 +2432,7 @@ class Trainer: if checkpoint is None: return - if self.deepspeed: + if self.is_deepspeed_enabled: # deepspeed loads optimizer/lr_scheduler together with the model in deepspeed_init return @@ -2418,7 +2456,6 @@ class Trainer: self.optimizer.load_state_dict(optimizer_state) self.lr_scheduler.load_state_dict(lr_scheduler_state) else: - map_location = "cpu" if is_sagemaker_mp_enabled() else self.args.device if is_sagemaker_mp_enabled(): if os.path.isfile(os.path.join(checkpoint, "user_content.pt")): # Optimizer checkpoint was saved with smp >= 1.10 @@ -2437,9 +2474,31 @@ class Trainer: self.model_wrapped.register_post_step_hook(opt_load_hook) else: - self.optimizer.load_state_dict( - torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location) - ) + # We use the CPU when training on one GPU to avoid OOM for GPU RAM when training big models. 
+ # In distributed training however, we load directly on each GPU and risk the GPU OOM as it's more + # likely to get OOM on CPU (since we load num_gpu times the optimizer state + map_location = self.args.device if self.args.world_size > 1 else "cpu" + if self.fsdp or self.is_fsdp_enabled: + if self.is_fsdp_enabled: + load_fsdp_optimizer( + self.accelerator.state.fsdp_plugin, + self.accelerator, + self.optimizer, + self.model, + checkpoint, + ) + else: + full_osd = None + # In FSDP, we need to load the full optimizer state dict on rank 0 and then shard it + if self.args.process_index == 0: + full_osd = torch.load(os.path.join(checkpoint, OPTIMIZER_NAME)) + # call scatter_full_optim_state_dict on all ranks + sharded_osd = self.model.__class__.scatter_full_optim_state_dict(full_osd, self.model) + self.optimizer.load_state_dict(sharded_osd) + else: + self.optimizer.load_state_dict( + torch.load(os.path.join(checkpoint, OPTIMIZER_NAME), map_location=map_location) + ) with warnings.catch_warnings(record=True) as caught_warnings: self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) reissue_pt_warnings(caught_warnings) @@ -2503,41 +2562,20 @@ class Trainer: """ if backend is None: backend = default_hp_search_backend() - if backend is None: - raise RuntimeError( - "At least one of optuna or ray should be installed. " - "To install optuna run `pip install optuna`. " - "To install ray run `pip install ray[tune]`. " - "To install sigopt run `pip install sigopt`." - ) backend = HPSearchBackend(backend) - if backend == HPSearchBackend.OPTUNA and not is_optuna_available(): - raise RuntimeError("You picked the optuna backend, but it is not installed. Use `pip install optuna`.") - if backend == HPSearchBackend.RAY and not is_ray_tune_available(): - raise RuntimeError( - "You picked the Ray Tune backend, but it is not installed. Use `pip install 'ray[tune]'`." - ) - if backend == HPSearchBackend.SIGOPT and not is_sigopt_available(): - raise RuntimeError("You picked the sigopt backend, but it is not installed. Use `pip install sigopt`.") - if backend == HPSearchBackend.WANDB and not is_wandb_available(): - raise RuntimeError("You picked the wandb backend, but it is not installed. Use `pip install wandb`.") + backend_obj = ALL_HYPERPARAMETER_SEARCH_BACKENDS[backend]() + backend_obj.ensure_available() self.hp_search_backend = backend if self.model_init is None: raise RuntimeError( "To use hyperparameter search, you need to pass your model through a model_init function." 
) - self.hp_space = default_hp_space[backend] if hp_space is None else hp_space + self.hp_space = backend_obj.default_hp_space if hp_space is None else hp_space self.hp_name = hp_name self.compute_objective = default_compute_objective if compute_objective is None else compute_objective - backend_dict = { - HPSearchBackend.OPTUNA: run_hp_search_optuna, - HPSearchBackend.RAY: run_hp_search_ray, - HPSearchBackend.SIGOPT: run_hp_search_sigopt, - HPSearchBackend.WANDB: run_hp_search_wandb, - } - best_run = backend_dict[backend](self, n_trials, direction, **kwargs) + best_run = backend_obj.run(self, n_trials, direction, **kwargs) self.hp_search_backend = None return best_run @@ -2569,11 +2607,11 @@ class Trainer: return type(data)(self._prepare_input(v) for v in data) elif isinstance(data, torch.Tensor): kwargs = {"device": self.args.device} - if self.deepspeed and (torch.is_floating_point(data) or torch.is_complex(data)): + if self.is_deepspeed_enabled and (torch.is_floating_point(data) or torch.is_complex(data)): # NLP models inputs are int/uint and those get adjusted to the right dtype of the # embedding. Other models such as wav2vec2's inputs are already float and thus # may need special handling to match the dtypes of the model - kwargs.update({"dtype": self.args.hf_deepspeed_config.dtype()}) + kwargs.update({"dtype": self.accelerator.state.deepspeed_plugin.hf_ds_config.dtype()}) return data.to(**kwargs) return data @@ -2605,16 +2643,13 @@ class Trainer: arguments, depending on the situation. """ if self.use_cuda_amp or self.use_cpu_amp: - if is_torch_greater_or_equal_than_1_10: - ctx_manager = ( - torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype) - if self.use_cpu_amp - else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype) - ) - else: - ctx_manager = torch.cuda.amp.autocast() + ctx_manager = ( + torch.cpu.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype) + if self.use_cpu_amp + else torch.cuda.amp.autocast(cache_enabled=cache_enabled, dtype=self.amp_dtype) + ) else: - ctx_manager = contextlib.nullcontext() if sys.version_info >= (3, 7) else contextlib.suppress() + ctx_manager = contextlib.nullcontext() return ctx_manager @@ -2649,22 +2684,15 @@ class Trainer: if self.args.n_gpu > 1: loss = loss.mean() # mean() to average on multi-gpu parallel training - if self.args.gradient_accumulation_steps > 1 and not self.deepspeed: - # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward` - loss = loss / self.args.gradient_accumulation_steps - if self.do_grad_scaling: self.scaler.scale(loss).backward() elif self.use_apex: with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() - elif self.deepspeed: - # loss gets scaled under gradient_accumulation_steps in deepspeed - loss = self.deepspeed.backward(loss) else: - loss.backward() + self.accelerator.backward(loss) - return loss.detach() + return loss.detach() / self.args.gradient_accumulation_steps def compute_loss(self, model, inputs, return_outputs=False): """ @@ -2683,7 +2711,11 @@ class Trainer: self._past = outputs[self.args.past_index] if labels is not None: - if unwrap_model(model)._get_name() in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): + if is_peft_available() and isinstance(model, PeftModel): + model_name = unwrap_model(model.base_model)._get_name() + else: + model_name = unwrap_model(model)._get_name() + if model_name in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.values(): loss = self.label_smoother(outputs, labels, shift_labels=True) 
else: loss = self.label_smoother(outputs, labels) @@ -2742,37 +2774,34 @@ class Trainer: ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp or self.fsdp is not None + or self.is_fsdp_enabled ): - state_dict = self.model.state_dict() - + state_dict = self.model.state_dict() if not self.is_fsdp_enabled else {} if self.args.should_save: self._save(output_dir, state_dict=state_dict) - elif self.deepspeed: - # this takes care of everything as long as we aren't under zero3 - if self.args.should_save: - self._save(output_dir) - - if is_deepspeed_zero3_enabled(): - # It's too complicated to try to override different places where the weights dump gets - # saved, so since under zero3 the file is bogus, simply delete it. The user should - # either user deepspeed checkpoint to resume or to recover full weights use - # zero_to_fp32.py stored in the checkpoint. + if self.is_fsdp_enabled: + # remove the dummy state_dict saved above if self.args.should_save: - file = os.path.join(output_dir, WEIGHTS_NAME) - if os.path.isfile(file): - # logger.info(f"deepspeed zero3: removing {file}, see zero_to_fp32.py to recover weights") - os.remove(file) + for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]: + file = os.path.join(output_dir, filename) + if os.path.isfile(file): + os.remove(file) + save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir) - # now save the real model if stage3_gather_16bit_weights_on_model_save=True - # if false it will not be saved. - # This must be called on all ranks - if not self.deepspeed.save_16bit_model(output_dir, WEIGHTS_NAME): - logger.warning( - "deepspeed.save_16bit_model didn't save the model, since" - " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use" - " zero_to_fp32.py to recover weights" - ) - self.deepspeed.save_checkpoint(output_dir) + elif self.is_deepspeed_enabled: + # this takes care of everything as long as we aren't under zero3 + if version.parse(accelerate_version) <= version.parse("0.20.3"): + raise ValueError("Install Accelerate from main branch") + try: + state_dict = self.accelerator.get_state_dict(self.deepspeed) + if self.args.should_save: + self._save(output_dir, state_dict=state_dict) + except ValueError: + logger.warning( + " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use" + " zero_to_fp32.py to recover weights" + ) + self.model_wrapped.save_checkpoint(output_dir) elif self.args.should_save: self._save(output_dir) @@ -2814,30 +2843,29 @@ class Trainer: output_dir = output_dir if output_dir is not None else self.args.output_dir os.makedirs(output_dir, exist_ok=True) logger.info(f"Saving model checkpoint to {output_dir}") + + supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel) # Save a trained model and configuration using `save_pretrained()`. 
# They can then be reloaded using `from_pretrained()` - if not isinstance(self.model, PreTrainedModel): - if isinstance(unwrap_model(self.model), PreTrainedModel): - if state_dict is None: - state_dict = self.model.state_dict() - unwrap_model(self.model).save_pretrained(output_dir, state_dict=filtered_state_dict) + if not isinstance(self.model, supported_classes): + if state_dict is None: + state_dict = self.model.state_dict() + + if isinstance(unwrap_model(self.model), supported_classes): + unwrap_model(self.model).save_pretrained( + output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors + ) else: logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") - if state_dict is None: - state_dict = self.model.state_dict() - torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) + if self.args.save_safetensors: + safetensors.torch.save_file(state_dict, os.path.join(output_dir, SAFE_WEIGHTS_NAME)) + else: + torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) else: - if self.save_prefixencoder: - print("Saving PrefixEncoder") - state_dict = self.model.state_dict() - filtered_state_dict = {} - for k, v in self.model.named_parameters(): - if v.requires_grad: - filtered_state_dict[k] = state_dict[k] - self.model.save_pretrained(output_dir, state_dict=filtered_state_dict) - else: - print("Saving the whole model") - self.model.save_pretrained(output_dir, state_dict=state_dict) + self.model.save_pretrained( + output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors + ) + if self.tokenizer is not None: self.tokenizer.save_pretrained(output_dir) @@ -2846,7 +2874,7 @@ class Trainer: def store_flos(self): # Storing the number of floating-point operations that went into the model - if self.args.local_rank != -1: + if self.args.parallel_mode == ParallelMode.DISTRIBUTED: self.state.total_flos += ( distributed_broadcast_scalars([self.current_flos], device=self.args.device).sum().item() ) @@ -2923,7 +2951,7 @@ class Trainer: Pass a dataset if you wish to override `self.eval_dataset`. If it is a [`~datasets.Dataset`], columns not accepted by the `model.forward()` method are automatically removed. It must implement the `__len__` method. - ignore_keys (`Lst[str]`, *optional*): + ignore_keys (`List[str]`, *optional*): A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions. metric_key_prefix (`str`, *optional*, defaults to `"eval"`): @@ -2988,7 +3016,7 @@ class Trainer: test_dataset (`Dataset`): Dataset to run the predictions on. If it is an `datasets.Dataset`, columns not accepted by the `model.forward()` method are automatically removed. Has to implement the method `__len__` - ignore_keys (`Lst[str]`, *optional*): + ignore_keys (`List[str]`, *optional*): A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions. 
metric_key_prefix (`str`, *optional*, defaults to `"test"`): @@ -3054,19 +3082,30 @@ class Trainer: prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only - # if eval is called w/o train init deepspeed here - if args.deepspeed and not self.deepspeed: - # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval - # from the checkpoint eventually - deepspeed_engine, _, _ = deepspeed_init( - self, num_training_steps=0, resume_from_checkpoint=None, inference=True - ) - self.model = deepspeed_engine.module - self.model_wrapped = deepspeed_engine - self.deepspeed = deepspeed_engine + # if eval is called w/o train, handle model prep here + if self.is_deepspeed_enabled and self.deepspeed is None: + _, _ = deepspeed_init(self, num_training_steps=0, inference=True) model = self._wrap_model(self.model, training=False, dataloader=dataloader) + if len(self.accelerator._models) == 0 and model is self.model: + model = ( + self.accelerator.prepare(model) + if self.is_deepspeed_enabled + else self.accelerator.prepare_model(model, evaluation_mode=True) + ) + + if self.is_fsdp_enabled: + self.model = model + + # for the rest of this function `model` is the outside model, whether it was wrapped or not + if model is not self.model: + self.model_wrapped = model + + # backward compatibility + if self.is_deepspeed_enabled: + self.deepspeed = self.model_wrapped + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called # while ``train`` is running, cast it to the right dtype first and then put on device if not self.is_in_train: @@ -3090,9 +3129,6 @@ class Trainer: # Do this before wrapping. eval_dataset = getattr(dataloader, "dataset", None) - if is_torch_tpu_available(): - dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device) - if args.past_index >= 0: self._past = None @@ -3123,37 +3159,41 @@ class Trainer: # Prediction step loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys) - inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None + main_input_name = getattr(self.model, "main_input_name", "input_ids") + inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None if is_torch_tpu_available(): xm.mark_step() # Update containers on host if loss is not None: - losses = self._nested_gather(loss.repeat(batch_size)) - losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0) + losses = self.accelerator.gather_for_metrics((loss.repeat(batch_size))) + losses_host = losses if losses_host is None else nested_concat(losses_host, losses, padding_index=-100) if labels is not None: - labels = self._pad_across_processes(labels) - labels = self._nested_gather(labels) - labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) + labels = self.accelerator.pad_across_processes(labels, dim=1, pad_index=-100) if inputs_decode is not None: - inputs_decode = self._pad_across_processes(inputs_decode) - inputs_decode = self._nested_gather(inputs_decode) + inputs_decode = self.accelerator.pad_across_processes(inputs_decode, dim=1, pad_index=-100) + inputs_decode = self.accelerator.gather_for_metrics((inputs_decode)) inputs_host = ( inputs_decode if inputs_host is None else nested_concat(inputs_host, inputs_decode, padding_index=-100) ) if logits is not None: - logits = 
self._pad_across_processes(logits) - logits = self._nested_gather(logits) + logits = self.accelerator.pad_across_processes(logits, dim=1, pad_index=-100) if self.preprocess_logits_for_metrics is not None: logits = self.preprocess_logits_for_metrics(logits, labels) + logits = self.accelerator.gather_for_metrics((logits)) preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) + + if labels is not None: + labels = self.accelerator.gather_for_metrics((labels)) + labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) + self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. - if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: + if args.eval_accumulation_steps is not None and self.accelerator.sync_gradients: if losses_host is not None: losses = nested_numpify(losses_host) all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) @@ -3211,17 +3251,6 @@ class Trainer: if num_samples == 0 and observed_num_examples > 0: num_samples = observed_num_examples - # Number of losses has been rounded to a multiple of batch_size and in a distributed training, the number of - # samplers has been rounded to a multiple of batch_size, so we truncate. - if all_losses is not None: - all_losses = all_losses[:num_samples] - if all_preds is not None: - all_preds = nested_truncate(all_preds, num_samples) - if all_labels is not None: - all_labels = nested_truncate(all_labels, num_samples) - if all_inputs is not None: - all_inputs = nested_truncate(all_inputs, num_samples) - # Metrics! if self.compute_metrics is not None and all_preds is not None and all_labels is not None: if args.include_inputs_for_metrics: @@ -3261,45 +3290,12 @@ class Trainer: tensors = nested_xla_mesh_reduce(tensors, name) elif is_sagemaker_mp_enabled(): tensors = smp_gather(tensors) - elif self.args.local_rank != -1: + elif (self.args.distributed_state is not None and self.args.distributed_state.distributed_type != "NO") or ( + self.args.distributed_state is None and self.args.local_rank != -1 + ): tensors = distributed_concat(tensors) return tensors - # Copied from Accelerate. - def _pad_across_processes(self, tensor, pad_index=-100): - """ - Recursively pad the tensors in a nested list/tuple/dictionary of tensors from all devices to the same size so - they can safely be gathered. - """ - if isinstance(tensor, (list, tuple)): - return type(tensor)(self._pad_across_processes(t, pad_index=pad_index) for t in tensor) - elif isinstance(tensor, dict): - return type(tensor)({k: self._pad_across_processes(v, pad_index=pad_index) for k, v in tensor.items()}) - elif not isinstance(tensor, torch.Tensor): - raise TypeError( - f"Can't pad the values of type {type(tensor)}, only of nested list/tuple/dicts of tensors." - ) - - if len(tensor.shape) < 2: - return tensor - # Gather all sizes - size = torch.tensor(tensor.shape, device=tensor.device)[None] - sizes = self._nested_gather(size).cpu() - - max_size = max(s[1] for s in sizes) - # When extracting XLA graphs for compilation, max_size is 0, - # so use inequality to avoid errors. 
- if tensor.shape[1] >= max_size: - return tensor - - # Then pad to the maximum size - old_size = tensor.shape - new_size = list(old_size) - new_size[1] = max_size - new_tensor = tensor.new_zeros(tuple(new_size)) + pad_index - new_tensor[:, : old_size[1]] = tensor - return new_tensor - def prediction_step( self, model: nn.Module, @@ -3322,7 +3318,7 @@ class Trainer: argument `labels`. Check your model's documentation for all accepted arguments. prediction_loss_only (`bool`): Whether or not to return the loss only. - ignore_keys (`Lst[str]`, *optional*): + ignore_keys (`List[str]`, *optional*): A list of keys in the output of your model (if it is a dictionary) that should be ignored when gathering predictions. @@ -3423,34 +3419,61 @@ class Trainer: else: return 0 + def init_hf_repo(self): + """ + Initializes a git repo in `self.args.hub_model_id`. + """ + # Only on process zero + if not self.is_world_process_zero(): + return + + if self.args.hub_model_id is None: + repo_name = Path(self.args.output_dir).absolute().name + else: + repo_name = self.args.hub_model_id + + repo_url = create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True) + self.hub_model_id = repo_url.repo_id + self.push_in_progress = None + def init_git_repo(self, at_init: bool = False): """ Initializes a git repo in `self.args.hub_model_id`. + + + This function is deprecated and will be removed in v4.34.0 of Transformers. + + + Args: at_init (`bool`, *optional*, defaults to `False`): Whether this function is called before any training or not. If `self.args.overwrite_output_dir` is `True` and `at_init` is `True`, the path to the repo (which is `self.args.output_dir`) might be wiped out. """ + warnings.warn( + "`Trainer.init_git_repo` is deprecated and will be removed in v4.34.0 of Transformers. Use " + "`Trainer.init_hf_repo` instead." + ) if not self.is_world_process_zero(): return - if self.args.hub_model_id is None: - repo_name = Path(self.args.output_dir).absolute().name - else: - repo_name = self.args.hub_model_id - if "/" not in repo_name: - repo_name = get_full_repo_name(repo_name, token=self.args.hub_token) - # Make sure the repo exists. - create_repo(repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True) + # Make sure the repo exists + retrieve "real" repo_id + repo_name = self.args.hub_model_id + if repo_name is None: + repo_name = Path(self.args.output_dir).absolute().name + repo_id = create_repo( + repo_id=repo_name, token=self.args.hub_token, private=self.args.hub_private_repo, exist_ok=True + ).repo_id + try: - self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token) + self.repo = Repository(self.args.output_dir, clone_from=repo_id, token=self.args.hub_token) except EnvironmentError: if self.args.overwrite_output_dir and at_init: # Try again after wiping output_dir shutil.rmtree(self.args.output_dir) - self.repo = Repository(self.args.output_dir, clone_from=repo_name, token=self.args.hub_token) + self.repo = Repository(self.args.output_dir, clone_from=repo_id, token=self.args.hub_token) else: raise @@ -3530,13 +3553,15 @@ class Trainer: # Only push from one node. if not self.is_world_process_zero() or self.args.hub_strategy == HubStrategy.END: return - # If we haven't finished the last push, we don't do this one. - if self.push_in_progress is not None and not self.push_in_progress.is_done: + # If we haven't finished the last push, we don't do this one unless args.hub_always_push=True. 
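The rewritten push logic below relies on `huggingface_hub.upload_folder` with `run_as_future=True`, which schedules the upload in the background and returns a future-like job instead of blocking. A minimal sketch (repo id and paths are hypothetical):

from huggingface_hub import upload_folder

job = upload_folder(
    repo_id="user/my-model",   # hypothetical repo
    folder_path="./output",    # hypothetical local checkpoint dir
    commit_message="Training in progress",
    run_as_future=True,        # return immediately with a background job
)
print(job.done())  # poll without blocking
job.result()       # or wait for the push to finish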
+ if not self.args.hub_always_push and self.push_in_progress is not None and not self.push_in_progress.is_done(): return output_dir = self.args.output_dir # To avoid a new synchronization of all model weights, we just copy the file from the checkpoint folder - modeling_files = [CONFIG_NAME, WEIGHTS_NAME] + modeling_files = [CONFIG_NAME, WEIGHTS_NAME, SAFE_WEIGHTS_NAME] + if is_peft_available(): + modeling_files.extend([ADAPTER_CONFIG_NAME, ADAPTER_WEIGHTS_NAME, ADAPTER_SAFE_WEIGHTS_NAME]) for modeling_file in modeling_files: if os.path.isfile(os.path.join(checkpoint_folder, modeling_file)): shutil.copy(os.path.join(checkpoint_folder, modeling_file), os.path.join(output_dir, modeling_file)) @@ -3546,49 +3571,64 @@ class Trainer: # Same for the training arguments torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) - try: - if self.args.hub_strategy == HubStrategy.CHECKPOINT: - # Temporarily move the checkpoint just saved for the push - tmp_checkpoint = os.path.join(output_dir, "last-checkpoint") - # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a - # subfolder. - if os.path.isdir(tmp_checkpoint): - shutil.rmtree(tmp_checkpoint) - shutil.move(checkpoint_folder, tmp_checkpoint) + if self.args.save_strategy == IntervalStrategy.STEPS: + commit_message = f"Training in progress, step {self.state.global_step}" + else: + commit_message = f"Training in progress, epoch {int(self.state.epoch)}" - if self.args.save_strategy == IntervalStrategy.STEPS: - commit_message = f"Training in progress, step {self.state.global_step}" - else: - commit_message = f"Training in progress, epoch {int(self.state.epoch)}" - _, self.push_in_progress = self.repo.push_to_hub( - commit_message=commit_message, blocking=False, auto_lfs_prune=True + model_push_job = upload_folder( + repo_id=self.hub_model_id, + folder_path=output_dir, + commit_message=commit_message, + token=self.args.hub_token, + run_as_future=True, + ignore_patterns=["_*", "**/*"], + ) + + push_jobs = [model_push_job] + + if self.args.hub_strategy in [HubStrategy.CHECKPOINT, HubStrategy.ALL_CHECKPOINTS]: + path_in_repo = ( + "last-checkpoint" if self.args.hub_strategy == HubStrategy.CHECKPOINT else Path(checkpoint_folder).name ) - finally: - if self.args.hub_strategy == HubStrategy.CHECKPOINT: - # Move back the checkpoint to its place - shutil.move(tmp_checkpoint, checkpoint_folder) + checkpoint_push = upload_folder( + repo_id=self.hub_model_id, + folder_path=checkpoint_folder, + path_in_repo=path_in_repo, + commit_message=commit_message + ", checkpoint", + token=self.args.hub_token, + run_as_future=True, + ) + push_jobs.append(checkpoint_push) + + if self.push_in_progress is None or self.push_in_progress.is_done(): + self.push_in_progress = PushInProgress(push_jobs) + else: + self.push_in_progress.jobs.extend(push_jobs) + + def _finish_current_push(self): + if not hasattr(self, "push_in_progress"): + return + if self.push_in_progress is not None and not self.push_in_progress.is_done(): + logger.info("Waiting for the current checkpoint push to be finished, this might take a couple of minutes.") + self.push_in_progress.wait_until_done() def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str: """ - Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*. + Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`. 
@@ -3546,49 +3571,64 @@ class Trainer:
         # Same for the training arguments
         torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
 
-        try:
-            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
-                # Temporarily move the checkpoint just saved for the push
-                tmp_checkpoint = os.path.join(output_dir, "last-checkpoint")
-                # We have to remove the "last-checkpoint" dir if it exists, otherwise the checkpoint is moved as a
-                # subfolder.
-                if os.path.isdir(tmp_checkpoint):
-                    shutil.rmtree(tmp_checkpoint)
-                shutil.move(checkpoint_folder, tmp_checkpoint)
+        if self.args.save_strategy == IntervalStrategy.STEPS:
+            commit_message = f"Training in progress, step {self.state.global_step}"
+        else:
+            commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
 
-            if self.args.save_strategy == IntervalStrategy.STEPS:
-                commit_message = f"Training in progress, step {self.state.global_step}"
-            else:
-                commit_message = f"Training in progress, epoch {int(self.state.epoch)}"
-            _, self.push_in_progress = self.repo.push_to_hub(
-                commit_message=commit_message, blocking=False, auto_lfs_prune=True
+        model_push_job = upload_folder(
+            repo_id=self.hub_model_id,
+            folder_path=output_dir,
+            commit_message=commit_message,
+            token=self.args.hub_token,
+            run_as_future=True,
+            ignore_patterns=["_*", "**/*"],
+        )
+
+        push_jobs = [model_push_job]
+
+        if self.args.hub_strategy in [HubStrategy.CHECKPOINT, HubStrategy.ALL_CHECKPOINTS]:
+            path_in_repo = (
+                "last-checkpoint" if self.args.hub_strategy == HubStrategy.CHECKPOINT else Path(checkpoint_folder).name
             )
-        finally:
-            if self.args.hub_strategy == HubStrategy.CHECKPOINT:
-                # Move back the checkpoint to its place
-                shutil.move(tmp_checkpoint, checkpoint_folder)
+            checkpoint_push = upload_folder(
+                repo_id=self.hub_model_id,
+                folder_path=checkpoint_folder,
+                path_in_repo=path_in_repo,
+                commit_message=commit_message + ", checkpoint",
+                token=self.args.hub_token,
+                run_as_future=True,
+            )
+            push_jobs.append(checkpoint_push)
+
+        if self.push_in_progress is None or self.push_in_progress.is_done():
+            self.push_in_progress = PushInProgress(push_jobs)
+        else:
+            self.push_in_progress.jobs.extend(push_jobs)
+
+    def _finish_current_push(self):
+        if not hasattr(self, "push_in_progress"):
+            return
+        if self.push_in_progress is not None and not self.push_in_progress.is_done():
+            logger.info("Waiting for the current checkpoint push to be finished, this might take a couple of minutes.")
+            self.push_in_progress.wait_until_done()
 
     def push_to_hub(self, commit_message: Optional[str] = "End of training", blocking: bool = True, **kwargs) -> str:
         """
-        Upload *self.model* and *self.tokenizer* to the 🤗 model hub on the repo *self.args.hub_model_id*.
+        Upload `self.model` and `self.tokenizer` to the 🤗 model hub on the repo `self.args.hub_model_id`.
 
         Parameters:
             commit_message (`str`, *optional*, defaults to `"End of training"`):
                 Message to commit while pushing.
             blocking (`bool`, *optional*, defaults to `True`):
                 Whether the function should return only when the `git push` has finished.
-            kwargs:
+            kwargs (`Dict[str, Any]`, *optional*):
                 Additional keyword arguments passed along to [`~Trainer.create_model_card`].
 
         Returns:
-            The url of the commit of your model in the given repository if `blocking=False`, a tuple with the url of
-            the commit and an object to track the progress of the commit if `blocking=True`
+            The URL of the repository where the model was pushed if `blocking=True`, or a `Future` object tracking the
+            progress of the commit if `blocking=False`.
         """
-        # If a user calls manually `push_to_hub` with `self.args.push_to_hub = False`, we try to create the repo but
-        # it might fail.
-        if not hasattr(self, "repo"):
-            self.init_git_repo()
-
         model_name = kwargs.pop("model_name", None)
         if model_name is None and self.args.should_save:
             if self.args.hub_model_id is None:
@@ -3596,6 +3636,10 @@ class Trainer:
             else:
                 model_name = self.args.hub_model_id.split("/")[-1]
 
+        # In case the user calls this method with args.push_to_hub = False
+        if self.hub_model_id is None:
+            self.init_hf_repo()
+
         # Needs to be executed on all processes for TPU training, but will only save on the processed determined by
         # self.args.should_save.
         self.save_model(_internal_call=True)
@@ -3604,25 +3648,19 @@ class Trainer:
         if not self.is_world_process_zero():
             return
 
-        # Cancel any async push in progress if blocking=True. The commits will all be pushed together.
-        if blocking and self.push_in_progress is not None and not self.push_in_progress.is_done:
-            self.push_in_progress._process.kill()
-            self.push_in_progress = None
+        self.create_model_card(model_name=model_name, **kwargs)
 
-        git_head_commit_url = self.repo.push_to_hub(
-            commit_message=commit_message, blocking=blocking, auto_lfs_prune=True
+        # Wait for the current upload to be finished.
+        self._finish_current_push()
+
+        return upload_folder(
+            repo_id=self.hub_model_id,
+            folder_path=self.args.output_dir,
+            commit_message=commit_message,
+            token=self.args.hub_token,
+            run_as_future=not blocking,
+            ignore_patterns=["_*", "**/*"],
         )
-        # push separately the model card to be independant from the rest of the model
-        if self.args.should_save:
-            self.create_model_card(model_name=model_name, **kwargs)
-            try:
-                self.repo.push_to_hub(
-                    commit_message="update model card README.md", blocking=blocking, auto_lfs_prune=True
-                )
-            except EnvironmentError as exc:
-                logger.error(f"Error pushing update to the model card. Please read logs and retry.\n${exc}")
-
-        return git_head_commit_url
 
     #
     # Deprecated code
     #
@@ -3648,22 +3686,30 @@ class Trainer:
         prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
 
-        # if eval is called w/o train init deepspeed here
-        if args.deepspeed and not self.deepspeed:
-            # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
-            # from the checkpoint eventually
-            deepspeed_engine, _, _ = deepspeed_init(self, num_training_steps=0, resume_from_checkpoint=None)
-            self.model = deepspeed_engine.module
-            self.model_wrapped = deepspeed_engine
-            self.deepspeed = deepspeed_engine
-            # XXX: we don't need optim/sched for inference, but this needs to be sorted out, since
-            # for example the Z3-optimizer is a must for zero3 to work even for inference - what we
-            # don't need is the deepspeed basic optimizer which is self.optimizer.optimizer
-            deepspeed_engine.optimizer.optimizer = None
-            deepspeed_engine.lr_scheduler = None
+        # if eval is called w/o train, handle model prep here
+        if self.is_deepspeed_enabled and self.deepspeed is None:
+            _, _ = deepspeed_init(self, num_training_steps=0, inference=True)
 
         model = self._wrap_model(self.model, training=False, dataloader=dataloader)
 
+        if len(self.accelerator._models) == 0 and model is self.model:
+            model = (
+                self.accelerator.prepare(model)
+                if self.is_deepspeed_enabled
+                else self.accelerator.prepare_model(model, evaluation_mode=True)
+            )
+
+            if self.is_fsdp_enabled:
+                self.model = model
+
+            # for the rest of this function `model` is the outside model, whether it was wrapped or not
+            if model is not self.model:
+                self.model_wrapped = model
+
+            # backward compatibility
+            if self.is_deepspeed_enabled:
+                self.deepspeed = self.model_wrapped
+
         # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
         # while ``train`` is running, cast it to the right dtype first and then put on device
         if not self.is_in_train:
@@ -3697,9 +3743,6 @@ class Trainer:
         model.eval()
 
-        if is_torch_tpu_available():
-            dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
-
         if args.past_index >= 0:
             self._past = None
@@ -3707,7 +3750,8 @@ class Trainer:
         for step, inputs in enumerate(dataloader):
             loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
-            inputs_decode = self._prepare_input(inputs["input_ids"]) if args.include_inputs_for_metrics else None
+            main_input_name = getattr(self.model, "main_input_name", "input_ids")
+            inputs_decode = self._prepare_input(inputs[main_input_name]) if args.include_inputs_for_metrics else None
 
             if loss is not None:
                 losses = loss.repeat(batch_size)
@@ -3785,7 +3829,7 @@ class Trainer:
             tensors = nested_xla_mesh_reduce(tensors, name)
         elif is_sagemaker_mp_enabled():
             tensors = smp_gather(tensors)
-        elif self.args.local_rank != -1:
+        elif self.args.parallel_mode == ParallelMode.DISTRIBUTED:
             tensors = distributed_concat(tensors)
 
         return nested_numpify(tensors)
@@ -3828,3 +3872,37 @@ class Trainer:
             if not self.repo.is_repo_clean():
                 self.repo.git_commit("Add *.sagemaker patterns to .gitignore.")
             self.repo.git_push()
+
+    def create_accelerator_and_postprocess(self):
+        grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps}
+        if version.parse(accelerate_version) > version.parse("0.20.3"):
+            grad_acc_kwargs["sync_with_dataloader"] = False
+        gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs)
+
+        # create accelerator object
+        self.accelerator = Accelerator(
+            dispatch_batches=self.args.dispatch_batches,
+            deepspeed_plugin=self.args.deepspeed_plugin,
+            gradient_accumulation_plugin=gradient_accumulation_plugin,
+        )
+
+        # deepspeed and accelerate flags covering both trainer args and accelerate launcher
+        self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None
+        self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None
+
+        # post accelerator creation setup
+        if self.is_fsdp_enabled:
+            fsdp_plugin = self.accelerator.state.fsdp_plugin
+            fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
+                "limit_all_gathers", fsdp_plugin.limit_all_gathers
+            )
+
+        if self.is_deepspeed_enabled:
+            if getattr(self.args, "hf_deepspeed_config", None) is None:
+                from transformers.deepspeed import HfTrainerDeepSpeedConfig
+
+                ds_plugin = self.accelerator.state.deepspeed_plugin
+
+                ds_plugin.hf_ds_config = HfTrainerDeepSpeedConfig(ds_plugin.hf_ds_config.config)
+                ds_plugin.deepspeed_config = ds_plugin.hf_ds_config.config
+                ds_plugin.hf_ds_config.trainer_config_process(self.args)
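Note: the diff below gives `Seq2SeqTrainer` a `save_prefixencoder` switch; when it is set, `_save` writes out only the parameters with `requires_grad=True` (the P-tuning prefix encoder) instead of the full state dict. A minimal sketch of that filtering on a stand-in module (all names here are illustrative):

    import torch
    from torch import nn

    model = nn.Linear(4, 2)            # stand-in for the real model
    model.bias.requires_grad = False   # pretend only `weight` is being tuned

    state_dict = model.state_dict()
    filtered = {k: state_dict[k] for k, v in model.named_parameters() if v.requires_grad}
    torch.save(filtered, "trainable_params.bin")   # saves just {"weight": ...}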
diff --git a/ptuning/trainer_seq2seq.py b/ptuning/trainer_seq2seq.py
index 19d5cf1..05be504 100644
--- a/ptuning/trainer_seq2seq.py
+++ b/ptuning/trainer_seq2seq.py
@@ -12,22 +12,68 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import os
 
 import torch
 from torch import nn
 from torch.utils.data import Dataset
 
+from transformers.data.data_collator import DataCollator
 from transformers.deepspeed import is_deepspeed_zero3_enabled
-from trainer import Trainer
-from transformers.trainer_utils import PredictionOutput
+from transformers import Trainer
+from transformers.modeling_utils import PreTrainedModel, unwrap_model, WEIGHTS_NAME
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from transformers.trainer_callback import TrainerCallback
+from transformers.trainer_utils import EvalPrediction, PredictionOutput
+from transformers.training_args import TrainingArguments
 from transformers.utils import logging
 
-
 logger = logging.get_logger(__name__)
 
+TRAINING_ARGS_NAME = "training_args.bin"
 
 class Seq2SeqTrainer(Trainer):
+    def __init__(self, save_prefixencoder=None, **kwargs):
+        super().__init__(**kwargs)
+        self.save_prefixencoder = save_prefixencoder
+
+    def _save(self, output_dir: Optional[str] = None, state_dict=None):
+        # If we are executing this function, we are the process zero, so we don't check for that.
+        output_dir = output_dir if output_dir is not None else self.args.output_dir
+        os.makedirs(output_dir, exist_ok=True)
+        logger.info(f"Saving model checkpoint to {output_dir}")
+        # Save a trained model and configuration using `save_pretrained()`.
+        # They can then be reloaded using `from_pretrained()`
+        if not isinstance(self.model, PreTrainedModel):
+            if isinstance(unwrap_model(self.model), PreTrainedModel):
+                if state_dict is None:
+                    state_dict = self.model.state_dict()
+                unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
+            else:
+                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+                if state_dict is None:
+                    state_dict = self.model.state_dict()
+                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+        else:
+            if self.save_prefixencoder:
+                print("Saving PrefixEncoder")
+                state_dict = self.model.state_dict()
+                filtered_state_dict = {}
+                for k, v in self.model.named_parameters():
+                    if v.requires_grad:
+                        filtered_state_dict[k] = state_dict[k]
+                self.model.save_pretrained(output_dir, state_dict=filtered_state_dict)
+            else:
+                print("Saving the whole model")
+                self.model.save_pretrained(output_dir, state_dict=state_dict)
+        if self.tokenizer is not None:
+            self.tokenizer.save_pretrained(output_dir)
+
+        # Good practice: save your training arguments together with the trained model
+        torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
+
+
     def evaluate(
         self,
         eval_dataset: Optional[Dataset] = None,
diff --git a/requirements.txt b/requirements.txt
index fb8d79f..e3636f3 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,13 @@
 protobuf
-transformers==4.27.1
+transformers==4.32.0
 cpm_kernels
 torch>=1.10
 gradio
 mdtex2html
 sentencepiece
-accelerate
\ No newline at end of file
+accelerate
+datasets
+jieba
+rouge_chinese
+deepspeed
+
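Note: requirements.txt now pins transformers==4.32.0 and adds the dependencies that main_parallel.py's metrics use (nltk is also imported there but is assumed to be installed separately). A toy sketch of how jieba, rouge_chinese, and nltk combine for the Chinese ROUGE/BLEU evaluation (the strings are illustrative, not project data):

    import jieba
    from rouge_chinese import Rouge
    from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction

    pred, label = "今天天气很好", "今天天气不错"   # toy prediction / reference
    hyp = " ".join(jieba.cut(pred))    # rouge_chinese expects space-joined tokens
    ref = " ".join(jieba.cut(label))

    rouge = Rouge().get_scores(hyp, ref)[0]        # rouge-1 / rouge-2 / rouge-l
    bleu = sentence_bleu([ref.split()], hyp.split(),
                         smoothing_function=SmoothingFunction().method3)
    print({k: round(v["f"], 4) for k, v in rouge.items()}, round(bleu, 4))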