mirror of https://github.com/hpcaitech/ColossalAI
support instrcut training (#3230)
parent
9bc702ab48
commit
bd39877da4
|
@ -119,10 +119,15 @@ def preprocess(
|
|||
class AlpacaDataset(Dataset):
|
||||
"""Dataset for supervised fine-tuning."""
|
||||
|
||||
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
|
||||
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_length: int=None):
|
||||
super(AlpacaDataset, self).__init__()
|
||||
logger.info("Loading data...")
|
||||
list_data_dict = jload(data_path)
|
||||
logger.info(f"Loaded {len(list_data_dict)} examples.")
|
||||
|
||||
if max_length is not None:
|
||||
logger.info(f"Truncating data to max length {max_length}...")
|
||||
list_data_dict = [example for example in list_data_dict if len(example["input"]) <= max_length]
|
||||
|
||||
logger.info("Formatting inputs...")
|
||||
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
|
||||
|
|
|
@ -60,3 +60,6 @@ class Actor(LoRAModule):
|
|||
logits = output['logits']
|
||||
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
|
||||
return log_probs[:, -num_actions:]
|
||||
|
||||
def get_base_model(self):
|
||||
return self.model
|
|
@ -36,3 +36,5 @@ class LlamaLM(LM):
|
|||
|
||||
super().__init__(model, lora_rank, lora_train_bias)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
|
||||
return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
|
||||
|
|
|
@ -61,13 +61,15 @@ class SFTTrainer(ABC):
|
|||
# train
|
||||
self.model.train()
|
||||
for batch_id, batch in enumerate(self.train_dataloader):
|
||||
prompt_ids = batch["input_ids"]
|
||||
p_mask = batch["attention_mask"]
|
||||
labels = batch["labels"]
|
||||
prompt_ids = prompt_ids.squeeze(1).cuda()
|
||||
p_mask = p_mask.squeeze(1).cuda()
|
||||
prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
|
||||
p_mask = batch["attention_mask"].to(torch.cuda.current_device())
|
||||
labels = batch["labels"].to(torch.cuda.current_device())
|
||||
# prompt_ids = prompt_ids.squeeze(1).cuda()
|
||||
# p_mask = p_mask.squeeze(1).cuda()
|
||||
# prompt_logits = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
|
||||
loss, prompt_logits = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
|
||||
outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
|
||||
loss = outputs.loss
|
||||
prompt_logits = outputs.logits
|
||||
|
||||
# loss = self.loss_fn(prompt_logits, labels)
|
||||
self.strategy.backward(loss, self.model, self.optimizer)
|
||||
|
@ -83,13 +85,16 @@ class SFTTrainer(ABC):
|
|||
loss_sum = 0
|
||||
num_seen = 0
|
||||
for batch in self.eval_dataloader:
|
||||
prompt_ids = batch["input_ids"]
|
||||
p_mask = batch["attention_mask"]
|
||||
prompt_ids = prompt_ids.squeeze(1).cuda()
|
||||
p_mask = p_mask.squeeze(1).cuda()
|
||||
prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
|
||||
p_mask = batch["attention_mask"].to(torch.cuda.current_device())
|
||||
labels = batch["labels"].to(torch.cuda.current_device())
|
||||
# prompt_ids = prompt_ids.squeeze(1).cuda()
|
||||
# p_mask = p_mask.squeeze(1).cuda()
|
||||
|
||||
outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
|
||||
loss = outputs.loss
|
||||
# prompt_logits = outputs.logits
|
||||
|
||||
prompt_logits = self.model(prompt_ids, attention_mask=p_mask)
|
||||
loss = self.loss_fn(prompt_logits, prompt_ids)
|
||||
loss_sum += loss.item()
|
||||
num_seen += prompt_ids.size(0)
|
||||
|
||||
|
|
|
@ -9,7 +9,6 @@ from chatgpt.models.base import Actor
|
|||
from chatgpt.models.lora import LoraLinear
|
||||
from torch.optim import Optimizer
|
||||
|
||||
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
||||
|
||||
|
|
|
@ -16,6 +16,8 @@ from typing import Dict
|
|||
|
||||
import transformers
|
||||
|
||||
from ..models.llama.llama_lm import LlamaLM
|
||||
|
||||
DEFAULT_PAD_TOKEN = "[PAD]"
|
||||
DEFAULT_EOS_TOKEN = "</s>"
|
||||
DEFAULT_BOS_TOKEN = "</s>"
|
||||
|
@ -60,6 +62,10 @@ def smart_tokenizer_and_embedding_resize(
|
|||
|
||||
if tokenizer.pad_token is None:
|
||||
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
|
||||
|
||||
if isinstance(model, LlamaLM):
|
||||
model = model.get_base_model()
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
if num_new_tokens > 0:
|
||||
|
|
|
@ -93,25 +93,27 @@ def train(args):
|
|||
elif 'alpaca' in args.dataset:
|
||||
train_dataset = AlpacaDataset(tokenizer=tokenizer, data_path=args.dataset)
|
||||
eval_dataset = None
|
||||
eval_dataset
|
||||
data_collator = AlpacaDataCollator(tokenizer=tokenizer)
|
||||
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
|
||||
logger.info("Using Distributed Sampler")
|
||||
train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
|
||||
if eval_dataset is not None:
|
||||
eval_sampler = DistributedSampler(eval_dataset, shuffle=False, seed=42, drop_last=False)
|
||||
else:
|
||||
sampler = None
|
||||
train_sampler = None
|
||||
eval_sampler = None
|
||||
|
||||
train_dataloader = DataLoader(train_dataset, shuffle=(sampler is None), sampler=sampler, batch_size=args.batch_size)
|
||||
train_dataloader = DataLoader(train_dataset, shuffle=(train_sampler is None), sampler=train_sampler, batch_size=args.batch_size, collate_fn=data_collator)
|
||||
if eval_dataset is not None:
|
||||
eval_dataloader = DataLoader(eval_dataset, batch_size=args.batch_size)
|
||||
eval_dataloader = DataLoader(eval_dataset, shuffle=(eval_sampler is None), sampler=eval_sampler, batch_size=args.batch_size, collate_fn=data_collator)
|
||||
else:
|
||||
eval_dataloader = None
|
||||
|
||||
trainer = SFTTrainer(model=model,
|
||||
strategy=strategy,
|
||||
optim=optim,
|
||||
train_dataloader=train_dataloader,
|
||||
eval_dataloader=eval_dataloader,
|
||||
sampler=sampler,
|
||||
batch_size=args.batch_size,
|
||||
max_epochs=args.max_epochs)
|
||||
|
||||
|
@ -128,7 +130,7 @@ if __name__ == '__main__':
|
|||
parser.add_argument('--strategy',
|
||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||
default='naive')
|
||||
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt'], default='bloom')
|
||||
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--dataset', type=str, default='yizhongw/self_instruct')
|
||||
parser.add_argument('--save_path', type=str, default='sft_ckpt.pth')
|
||||
|
|
|
@ -17,4 +17,10 @@ set_n_least_used_CUDA_VISIBLE_DEVICES 8
|
|||
|
||||
#torchrun --standalone --nproc_per_node=2 train_sft.py --pretrain 'bigscience/bloomz-560m' --model 'bloom' --strategy colossalai_zero2 --log_interval 10
|
||||
#torchrun --standalone --nproc_per_node=8 train_sft.py --model 'gpt2' --strategy colossalai_zero2 --batch_size 1 --log_interval 10
|
||||
torchrun --standalone --nproc_per_node=2 train_sft.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2 --log_interval 10
|
||||
torchrun --standalone --nproc_per_node=8 train_sft.py \
|
||||
--pretrain "/data/personal/nus-mql/LLAMA-7B" \
|
||||
--model 'llama' \
|
||||
--strategy colossalai_zero2 \
|
||||
--log_interval 10 \
|
||||
--save_path /data/personal/nus-mql/Coati-7B \
|
||||
--dataset /data/personal/nus-mql/stanford_alpaca/alpaca_data.json
|
||||
|
|
|
@ -1 +1 @@
|
|||
0.1.0
|
||||
1.0.0
|
||||
|
|
Loading…
Reference in New Issue