mirror of https://github.com/hpcaitech/ColossalAI
[eval] update llama npu eval (#5366)
parent 44ca61a22b
commit a5756a8720
@@ -3,6 +3,8 @@ from typing import List
 
 import torch
 
+from colossalai.utils import get_current_device
+
 from .huggingface import HuggingFaceModel
 
 IGNORE_INDEX = -100
@@ -126,9 +128,9 @@ class ChatGLMModel(HuggingFaceModel):
         """
         input_ids = torch.nn.utils.rnn.pad_sequence(
             input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
-        ).to(torch.cuda.current_device())
+        ).to(get_current_device())
         labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
-            torch.cuda.current_device()
+            get_current_device()
         )
 
         outputs = self.model(input_ids)[0]
@@ -197,7 +199,7 @@ class ChatGLM2Model(ChatGLMModel):
             truncation=True,
             return_tensors="pt",
             max_length=self.model_max_length - max_new_tokens,
-        ).to(torch.cuda.current_device())
+        ).to(get_current_device())
 
         # Set output_scores=True to get prediction scores.
         outputs = self.model.generate(
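The ChatGLM hunks above all make the same substitution: the hard-coded `torch.cuda.current_device()` call is replaced by ColossalAI's `get_current_device()`, which resolves to whatever backend the runtime is using (CUDA GPU, Ascend NPU, or CPU), so the eval code no longer assumes CUDA. A minimal sketch of the padding pattern, with the pad id and input tensors as illustrative placeholders:

```python
# Minimal sketch of the device-agnostic padding pattern; pad_token_id and the
# input tensors are placeholders, not values taken from the ColossalEval code.
import torch
from colossalai.utils import get_current_device

pad_token_id = 0
input_ids_list = [torch.tensor([5, 6, 7]), torch.tensor([8, 9])]

# pad_sequence builds a [batch, max_len] tensor; .to(get_current_device()) moves it
# to cuda:N, npu:N, or cpu depending on which accelerator backend is active.
input_ids = torch.nn.utils.rnn.pad_sequence(
    input_ids_list, batch_first=True, padding_value=pad_token_id
).to(get_current_device())
attention_mask = input_ids.ne(pad_token_id).to(get_current_device())
```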
@@ -11,6 +11,7 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, AutoTokeni
 
 from colossalai.logging import DistributedLogger
 from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.utils import get_current_device
 
 from .base import BaseModel
 
@@ -128,12 +129,12 @@ class HuggingFaceModel(BaseModel):
             self.model = AutoModel.from_pretrained(path, **model_kwargs)
             shard_former = ShardFormer(shard_config)
             self.model, sharded_parameters = shard_former.optimize(self.model)
-            self.model.to(torch.cuda.current_device())
+            self.model.to(get_current_device())
 
             if peft_path is not None:
                 raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
         else:
-            self.model = AutoModel.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
+            self.model = AutoModel.from_pretrained(path, **model_kwargs).to(get_current_device())
             if peft_path is not None:
                 self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
         self.model.eval()
@@ -155,11 +156,11 @@ class HuggingFaceModel(BaseModel):
         """
         input_ids = torch.nn.utils.rnn.pad_sequence(
             input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
-        ).to(torch.cuda.current_device())
+        ).to(get_current_device())
         labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
-            torch.cuda.current_device()
+            get_current_device()
         )
-        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(torch.cuda.current_device())
+        attention_mask = input_ids.ne(self.tokenizer.pad_token_id).to(get_current_device())
 
         outputs = self.model(input_ids, attention_mask=attention_mask)[0]
 
@@ -464,7 +465,7 @@ class HuggingFaceModel(BaseModel):
             return_tensors="pt",
             return_token_type_ids=False,
             max_length=self.model_max_length - max_new_tokens,
-        ).to(torch.cuda.current_device())
+        ).to(get_current_device())
 
         # Set output_scores=True to get prediction scores.
         outputs = self.model.generate(
@@ -598,12 +599,12 @@ class HuggingFaceCausalLM(HuggingFaceModel):
             self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs)
             shard_former = ShardFormer(shard_config)
             self.model, sharded_parameters = shard_former.optimize(self.model)
-            self.model.to(torch.cuda.current_device())
+            self.model.to(get_current_device())
 
             if peft_path is not None:
                 raise NotImplementedError("ShardFormer for PEFT models is not implemented.")
         else:
-            self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(torch.cuda.current_device())
+            self.model = AutoModelForCausalLM.from_pretrained(path, **model_kwargs).to(get_current_device())
             if peft_path is not None:
                 self.model = PeftModel.from_pretrained(self.model, peft_path, is_trainable=False)
 
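The HuggingFaceModel and HuggingFaceCausalLM hunks apply the same idea to model placement: whether the model is sharded through ShardFormer or loaded directly, it now lands on `get_current_device()` instead of `torch.cuda.current_device()`. A rough sketch of the non-sharded load path, where the checkpoint name and dtype kwarg are assumptions for illustration only:

```python
# Sketch of the non-sharded load path shown above; "some-org/some-causal-lm"
# and torch_dtype are illustrative assumptions, not ColossalEval defaults.
import torch
from transformers import AutoModelForCausalLM
from colossalai.utils import get_current_device

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-causal-lm", torch_dtype=torch.float16
).to(get_current_device())  # placed on GPU, NPU, or CPU as appropriate
model.eval()  # evaluation only, no gradients needed
```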
@@ -8,6 +8,7 @@ import torch.distributed as dist
 from colossal_eval import dataset, models, utils
 
 import colossalai
+from colossalai.accelerator import get_accelerator
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.logging import get_dist_logger
 from colossalai.shardformer import ShardConfig
@@ -82,6 +83,7 @@ def rm_and_merge(
 
 def main(args):
     colossalai.launch_from_torch(config={}, seed=42)
+    accelerator = get_accelerator()
     world_size = dist.get_world_size()
 
     rank = dist.get_rank()
@@ -235,10 +237,10 @@ def main(args):
                 ),
             )
 
-        logger.info(f"Rank {rank} peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB")
+        logger.info(f"Rank {rank} peak device mem: {accelerator.max_memory_allocated()/1024**3:.3f} GB")
 
         del model_
-        torch.cuda.empty_cache()
+        accelerator.empty_cache()
 
     dist.barrier()
     if rank == 0:
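In the inference script, the CUDA-only bookkeeping is routed through `colossalai.accelerator.get_accelerator()`, which exposes `max_memory_allocated()` and `empty_cache()` for the active backend, so peak-memory logging and cache clearing behave the same on NPU as on GPU. A minimal sketch of that usage (the print stands in for the script's logger call):

```python
# Minimal sketch of the accelerator-based bookkeeping from the hunk above.
# get_accelerator() returns the wrapper for the active backend (CUDA or NPU);
# its max_memory_allocated()/empty_cache() mirror the torch.cuda calls they replace.
from colossalai.accelerator import get_accelerator

accelerator = get_accelerator()
peak_gb = accelerator.max_memory_allocated() / 1024**3
print(f"peak device mem: {peak_gb:.3f} GB")  # stand-in for logger.info in the script
accelerator.empty_cache()
```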