Browse Source

support instrcut training (#3230)

pull/3221/head
Fazzie-Maqianli 2 years ago committed by GitHub
parent
commit
bd39877da4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 7
      applications/ChatGPT/chatgpt/dataset/sft_dataset.py
  2. 3
      applications/ChatGPT/chatgpt/models/base/actor.py
  3. 2
      applications/ChatGPT/chatgpt/models/llama/llama_lm.py
  4. 29
      applications/ChatGPT/chatgpt/trainer/sft.py
  5. 1
      applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py
  6. 6
      applications/ChatGPT/chatgpt/utils/tokenizer_utils.py
  7. 18
      applications/ChatGPT/examples/train_sft.py
  8. 8
      applications/ChatGPT/examples/train_sft.sh
  9. 2
      applications/ChatGPT/version.txt

7
applications/ChatGPT/chatgpt/dataset/sft_dataset.py

@ -119,10 +119,15 @@ def preprocess(
class AlpacaDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_length: int=None):
super(AlpacaDataset, self).__init__()
logger.info("Loading data...")
list_data_dict = jload(data_path)
logger.info(f"Loaded {len(list_data_dict)} examples.")
if max_length is not None:
logger.info(f"Truncating data to max length {max_length}...")
list_data_dict = [example for example in list_data_dict if len(example["input"]) <= max_length]
logger.info("Formatting inputs...")
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]

3
applications/ChatGPT/chatgpt/models/base/actor.py

@ -60,3 +60,6 @@ class Actor(LoRAModule):
logits = output['logits']
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]
def get_base_model(self):
return self.model

2
applications/ChatGPT/chatgpt/models/llama/llama_lm.py

@ -36,3 +36,5 @@ class LlamaLM(LM):
super().__init__(model, lora_rank, lora_train_bias)
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)

29
applications/ChatGPT/chatgpt/trainer/sft.py

@ -61,13 +61,15 @@ class SFTTrainer(ABC):
# train
self.model.train()
for batch_id, batch in enumerate(self.train_dataloader):
prompt_ids = batch["input_ids"]
p_mask = batch["attention_mask"]
labels = batch["labels"]
prompt_ids = prompt_ids.squeeze(1).cuda()
p_mask = p_mask.squeeze(1).cuda()
prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
p_mask = batch["attention_mask"].to(torch.cuda.current_device())
labels = batch["labels"].to(torch.cuda.current_device())
# prompt_ids = prompt_ids.squeeze(1).cuda()
# p_mask = p_mask.squeeze(1).cuda()
# prompt_logits = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
loss, prompt_logits = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
loss = outputs.loss
prompt_logits = outputs.logits
# loss = self.loss_fn(prompt_logits, labels)
self.strategy.backward(loss, self.model, self.optimizer)
@ -83,13 +85,16 @@ class SFTTrainer(ABC):
loss_sum = 0
num_seen = 0
for batch in self.eval_dataloader:
prompt_ids = batch["input_ids"]
p_mask = batch["attention_mask"]
prompt_ids = prompt_ids.squeeze(1).cuda()
p_mask = p_mask.squeeze(1).cuda()
prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
p_mask = batch["attention_mask"].to(torch.cuda.current_device())
labels = batch["labels"].to(torch.cuda.current_device())
# prompt_ids = prompt_ids.squeeze(1).cuda()
# p_mask = p_mask.squeeze(1).cuda()
outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
loss = outputs.loss
# prompt_logits = outputs.logits
prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
loss = self.loss_fn(prompt_logits, prompt_ids)
loss_sum += loss.item()
num_seen += prompt_ids.size(0)

1
applications/ChatGPT/chatgpt/trainer/strategies/colossalai.py

@ -9,7 +9,6 @@ from chatgpt.models.base import Actor
from chatgpt.models.lora import LoraLinear
from torch.optim import Optimizer
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

6
applications/ChatGPT/chatgpt/utils/tokenizer_utils.py

@ -16,6 +16,8 @@ from typing import Dict
import transformers
from ..models.llama.llama_lm import LlamaLM
DEFAULT_PAD_TOKEN = "[PAD]"
DEFAULT_EOS_TOKEN = "</s>"
DEFAULT_BOS_TOKEN = "</s>"
@ -60,6 +62,10 @@ def smart_tokenizer_and_embedding_resize(
if tokenizer.pad_token is None:
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
if isinstance(model, LlamaLM):
model = model.get_base_model()
model.resize_token_embeddings(len(tokenizer))
if num_new_tokens > 0:

18
applications/ChatGPT/examples/train_sft.py

@ -93,25 +93,27 @@ def train(args):
elif 'alpaca' in args.dataset:
train_dataset = AlpacaDataset(tokenizer=tokenizer, data_path=args.dataset)
eval_dataset = None
eval_dataset
data_collator = AlpacaDataCollator(tokenizer=tokenizer)
if dist.is_initialized() and dist.get_world_size() > 1:
sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
logger.info("Using Distributed Sampler")
train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
if eval_dataset is not None:
eval_sampler = DistributedSampler(eval_dataset, shuffle=False, seed=42, drop_last=False)
else:
sampler = None
train_sampler = None
eval_sampler = None
train_dataloader = DataLoader(train_dataset, shuffle=(sampler is None), sampler=sampler, batch_size=args.batch_size)
train_dataloader = DataLoader(train_dataset, shuffle=(train_sampler is None), sampler=train_sampler, batch_size=args.batch_size, collate_fn=data_collator)
if eval_dataset is not None:
eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size)
eval_dataloader = DataLoader(eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, collate_fn=data_collator)
else:
eval_dataloader = None
trainer = SFTTrainer(model=model,
strategy=strategy,
optim=optim,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
sampler=sampler,
batch_size=args.batch_size,
max_epochs=args.max_epochs)
@ -128,7 +130,7 @@ if __name__ == '__main__':
parser.add_argument('--strategy',
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
default='naive')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
parser.add_argument('--pretrain', type=str, default=None)
parser.add_argument('--dataset', type=str, default='yizhongw/self_instruct')
parser.add_argument('--save_path', type=str, default='sft_ckpt.pth')

8
applications/ChatGPT/examples/train_sft.sh

@ -17,4 +17,10 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 8
#torchrun --standalone --nproc_per_node=2 train_sft.py --pretrain 'bigscience/bloomz-560m' --model 'bloom' --strategy colossalai_zero2 --log_interval 10
#torchrun --standalone --nproc_per_node=8 train_sft.py --model 'gpt2' --strategy colossalai_zero2 --batch_size 1 --log_interval 10
torchrun --standalone --nproc_per_node=2 train_sft.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2 --log_interval 10
torchrun --standalone --nproc_per_node=8 train_sft.py \
--pretrain "/data/personal/nus-mql/LLAMA-7B" \
--model 'llama' \
--strategy colossalai_zero2 \
--log_interval 10 \
--save_path /data/personal/nus-mql/Coati-7B \
--dataset /data/personal/nus-mql/stanford_alpaca/alpaca_data.json

2
applications/ChatGPT/version.txt

@ -1 +1 @@
0.1.0
1.0.0

Loading…
Cancel
Save