mirror of https://github.com/hpcaitech/ColossalAI
commit 6a6634b6e8 ("add ppo"), parent eb6337f07f
@@ -162,4 +162,5 @@ coverage.xml
# log, test files - ColossalChat
applications/ColossalChat/logs
applications/ColossalChat/wandb
applications/ColossalChat/tests/logs
@@ -14,6 +14,7 @@ from colossalai.booster.plugin import HybridParallelPlugin
from colossalai.initialize import launch
from colossalai.nn.optimizer import HybridAdam
from colossalai.utils import get_current_device
from colossalai.shardformer.policies.auto_policy import get_autopolicy

from .comm import ray_broadcast_tensor_dict
from .utils import bind_batch, post_recv, unbind_batch
@@ -76,6 +77,10 @@ class BaseConsumer:
        plugin_config.update(self.plugin_config)
        self.plugin = HybridParallelPlugin(**plugin_config)
        self.booster = Booster(plugin=self.plugin)
        if hasattr(self, "critic_model"):
            # The critic wraps the base model, so it gets its own auto policy, plugin, and booster.
            plugin_config.update({"custom_policy": get_autopolicy(self.critic_model.model)})
            self.critic_plugin = HybridParallelPlugin(**plugin_config)
            self.critic_booster = Booster(plugin=self.critic_plugin)
        self.dp_rank = dist.get_rank(self.plugin.dp_group)
        self.dp_size = dist.get_world_size(self.plugin.dp_group)
@@ -60,6 +60,7 @@ class TransformersInferenceBackend(BaseInferenceBackend):
        self.model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(path, **model_config)
        self.generate_config = generate_config.copy()
        self.generate_config.update(self.FORCE_GENERATE_CONFIG)
        self.generate_config["tokenizer"] = tokenizer
        self.tokenizer = tokenizer

    @torch.no_grad()
@@ -76,21 +77,26 @@ class TransformersInferenceBackend(BaseInferenceBackend):
            action_log_probs.append(log_probs_from_logits(logits[:, None, :], new_token_ids[:, i : i + 1]))
        action_log_probs = torch.cat(action_log_probs, dim=1)
        # get action mask
        response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int).to(get_current_device())
        action_mask = torch.ones_like(new_token_ids, dtype=attention_mask.dtype)
        if self.tokenizer.eos_token_id is not None:
            for indices in torch.nonzero(new_token_ids == self.tokenizer.eos_token_id):
                action_mask[indices[0], indices[1] + 1 :] = 0
        response_idx[:, 0] = input_len
        response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1

        if attention_mask.size(0) != action_mask.size(0):
            assert action_mask.size(0) % attention_mask.size(0) == 0
            attention_mask = attention_mask.repeat_interleave(action_mask.size(0) // attention_mask.size(0), dim=0)

        attention_mask = torch.cat((attention_mask, action_mask), dim=1)

        data = {
            "input_ids": out.sequences,
            "attention_mask": attention_mask,
            "action_log_probs": action_log_probs,
            "action_mask": action_mask,
            "response_idx": response_idx,
        }
        return data
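For intuition, here is a small self-contained sketch of the masking logic added above, using toy tensors and an assumed eos_token_id of 2 (not the backend's real data): every token after the first EOS is zeroed out of action_mask, and response_idx records the inclusive [start, end] span of the valid response inside the padded sequence.

import torch

input_len = 3                                    # toy prompt length
new_token_ids = torch.tensor([[5, 7, 2, 9, 9]])  # toy response; EOS (id 2) at position 2
eos_token_id = 2

action_mask = torch.ones_like(new_token_ids)
for indices in torch.nonzero(new_token_ids == eos_token_id):
    action_mask[indices[0], indices[1] + 1 :] = 0  # mask everything after the EOS

response_idx = torch.zeros((new_token_ids.size(0), 2), dtype=torch.int)
response_idx[:, 0] = input_len
response_idx[:, 1] = input_len + action_mask.sum(dim=1) - 1

print(action_mask)   # tensor([[1, 1, 1, 0, 0]])
print(response_idx)  # tensor([[3, 5]], dtype=torch.int32)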
@@ -154,7 +160,6 @@ class VLLMInferenceBackend(BaseInferenceBackend):
    )
    FORCE_GENERATE_CONFIG = dict(
        logprobs=0,
        n=4,
    )

    def __init__(self, model_config: Dict[str, Any], generate_config: Dict[str, Any], tokenizer: PreTrainedTokenizer):
@@ -167,7 +172,7 @@ class VLLMInferenceBackend(BaseInferenceBackend):
        generate_config.update(self.FORCE_GENERATE_CONFIG)
        self.generate_config = SamplingParams(**generate_config)
        self.tokenizer = tokenizer
        self.num_generations = self.FORCE_GENERATE_CONFIG["n"]
        self.num_generations = generate_config["n"]

    @torch.no_grad()
    def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
@@ -4,11 +4,13 @@ import ray

from .consumer import SimpleConsumer
from .grpo_consumer import GRPOConsumer
from .ppo_consumer import PPOConsumer
from .producer import SimpleProducer

ALGO_MAP = {
    "Simple": SimpleConsumer,
    "GRPO": GRPOConsumer,
    "PPO": PPOConsumer,
}
@@ -45,3 +45,31 @@ class PolicyLoss(nn.Module):
        loss = loss.mean(dim=1)
        loss = loss.mean()
        return loss, skip, ratio_.max()


class ValueLoss(nn.Module):
    """
    Value Loss for PPO
    """

    def __init__(self, clip_eps: float = 0.2) -> None:
        super().__init__()
        self.clip_eps = clip_eps

    def forward(
        self,
        values: torch.Tensor,
        old_values: torch.Tensor,
        advantage: torch.Tensor,
        action_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        returns = advantage + old_values
        values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
        surr1 = (values_clipped - returns) ** 2
        surr2 = (values - returns) ** 2
        if action_mask is not None:
            # loss = torch.sum(torch.max(surr1, surr2) / torch.sum(action_mask) * action_mask)
            loss = torch.mean(masked_mean(torch.max(surr1, surr2), action_mask))
        else:
            loss = torch.mean(torch.max(surr1, surr2))
        return 0.5 * loss
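To make the clipping in ValueLoss above easy to inspect, here is a standalone restatement of the same objective with plain torch ops and toy tensors (assumed values; the PR's masked_mean helper is replaced by an explicit masked average, which is equivalent for this single-sample batch):

import torch

clip_eps = 0.2
values = torch.tensor([[0.8, 1.4, 0.3]])       # current critic predictions (toy)
old_values = torch.tensor([[1.0, 1.0, 0.5]])   # critic predictions at rollout time (toy)
advantage = torch.tensor([[0.5, -0.2, 0.0]])   # advantages (toy)
action_mask = torch.tensor([[1.0, 1.0, 0.0]])  # last position is padding

returns = advantage + old_values
values_clipped = old_values + (values - old_values).clamp(-clip_eps, clip_eps)
surr1 = (values_clipped - returns) ** 2  # error of the clipped value update
surr2 = (values - returns) ** 2          # error of the unclipped update
per_token = torch.max(surr1, surr2)      # pessimistic: take the larger error

loss = 0.5 * (per_token * action_mask).sum() / action_mask.sum()
print(loss)  # tensor(0.2125)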
@@ -0,0 +1,262 @@
from contextlib import nullcontext
from typing import Optional

import ray
import torch
import wandb
from coati.distributed.consumer import BaseConsumer
from coati.distributed.loss import PolicyLoss, ValueLoss
from coati.distributed.reward.reward_fn import math_reward_fn
from coati.distributed.reward.verifiable_reward import VerifiableReward
from coati.distributed.utils import calc_action_log_probs, compute_reward_ppo
from coati.trainer.utils import all_reduce_mean
from coati.models import Critic, disable_dropout
from transformers import AutoModelForCausalLM, AutoTokenizer

from colossalai.nn.optimizer import HybridAdam

@ray.remote
class PPOConsumer(BaseConsumer):
    def __init__(
        self,
        num_producers,
        num_episodes,
        rank,
        world_size,
        master_addr,
        master_port,
        num_update_per_episode,
        num_recv_per_update,
        batch_size,
        model_config,
        plugin_config,
        microbatch_size=1,
        num_generations=1,
        gamma: float = 1.0,
        lam: float = 0.95,
        kl_coef: float = 0.05,
        use_wandb=True,
    ):
        super().__init__(
            num_producers,
            num_episodes,
            rank,
            world_size,
            master_addr,
            master_port,
            num_update_per_episode,
            num_recv_per_update,
            batch_size,
            model_config,
            plugin_config,
            microbatch_size,
        )
        self.gamma = gamma
        self.lam = lam
        self.kl_coef = kl_coef
        path = model_config.pop("path")
        self.policy_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
        self.policy_model.train()
        self.policy_model.gradient_checkpointing_enable()
        self.critic_model = Critic(path, **model_config)
        self.critic_model.model.gradient_checkpointing_enable()
        self.critic_model.train()

        # Disable dropout
        disable_dropout(self.policy_model)
        disable_dropout(self.critic_model)

        self.optimizer = HybridAdam(self.policy_model.parameters(), lr=1e-6)
        self.critic_optimizer = HybridAdam(self.critic_model.parameters(), lr=1e-6)

        self.accum_loss = torch.zeros(1, device=self.device)
        self.accum_reward = torch.zeros(1, device=self.device)
        self.accum_kl = torch.zeros(1, device=self.device)
        self.accum_advantage = torch.zeros(1, device=self.device)
        self.accum_critic_loss = torch.zeros(1, device=self.device)
        self.accum_count = 0

        # Reference model is initialized from policy model.
        self.reference_model = AutoModelForCausalLM.from_pretrained(path, **model_config)
        self.reference_model.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.pad_token_id = self.tokenizer.pad_token_id
        self.num_generations = num_generations

        # Initialize verifiable reward.
        response_format_tags = {
            "think_start": {"text": "<think>", "num_occur": 1},
            "think_end": {"text": "</think>", "num_occur": 1},
            "answer_start": {"text": "<answer>", "num_occur": 1},
            "answer_end": {"text": "</answer>", "num_occur": 1},
        }
        self.reward_model = VerifiableReward(
            reward_fns=[math_reward_fn], tokenizer=self.tokenizer, tags=response_format_tags
        )

        self.policy_loss_fn = PolicyLoss()
        self.critic_loss_fn = ValueLoss()
        self.global_step = 0
        if use_wandb and self.rank == 0:
            self.wandb_run = wandb.init(project="PPO-Test", sync_tensorboard=True)
    def setup(self):
        super().setup()
        self.policy_model, self.optimizer, *_ = self.booster.boost(self.policy_model, self.optimizer)
        self.critic_model, self.critic_optimizer, *_ = self.critic_booster.boost(
            self.critic_model, self.critic_optimizer
        )
        self.reference_model, *_ = self.booster.boost(self.reference_model)

    def step(self, step_idx: int, **kwargs) -> Optional[float]:
        """
        Step data from policy model:
            [{
                "input_ids": torch.Tensor,
                "attention_mask": torch.Tensor,
                "action_mask": torch.Tensor,
                "action_log_probs": torch.Tensor,
            },
            ...]
        Format:
            [batch_size, num_of_generation, prompt_length + response_length] --- <PAD>...<PAD><PROMPT>...<PROMPT><RESPONSE>...<RESPONSE><PAD>...<PAD>.
        """

        # Reshape to [batch_size x num_of_generation, prompt_length + response_length]
        data = {k: v.view(-1, v.size(-1)) for k, v in kwargs.items()}
        action_mask = data["action_mask"]
        num_action = action_mask.shape[1]
        old_action_log_probs = data["action_log_probs"].detach()

        need_update = (step_idx + 1) % self.num_microbatches == 0

        ctx = nullcontext() if need_update else self.booster.no_sync(self.policy_model, self.optimizer)
        with ctx:
            policy_model_logits = self.policy_model(
                input_ids=data["input_ids"],
                attention_mask=data["attention_mask"],
            )["logits"]
            action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)

            with torch.no_grad():
                reference_model_logits = self.reference_model(
                    input_ids=data["input_ids"],
                    attention_mask=data["attention_mask"],
                )["logits"]
                reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)

            value = self.critic_model(
                input_ids=data["input_ids"],
                attention_mask=data["attention_mask"],
            )
            # Keep the value predictions aligned with the response tokens (offset by one position).
            value = value[:, -num_action - 1 : -1] * action_mask

            r = self.reward_model(
                data["input_ids"], gt_answer=data["gt_answer"], response_idx=data["response_idx"]
            )
            reward, kl = compute_reward_ppo(
                r, self.kl_coef, old_action_log_probs, reference_action_log_probs, action_mask=action_mask
            )

            # Calculate advantages
            # reference: https://github.com/huggingface/trl/blob/e3244d2d096ff1e2e248c931d06d39e165e20623/trl/trainer/ppo_trainer.py#L514C17-L523C46
            lastgaelam = 0
            advantage_reversed = []
            for t in reversed(range(num_action)):
                nextvalues = value[:, t + 1] if t < num_action - 1 else 0.0
                delta = reward[:, t] + self.gamma * nextvalues - value[:, t]
                lastgaelam = delta + self.gamma * self.lam * lastgaelam
                advantage_reversed.append(lastgaelam)
            advantage = torch.stack(advantage_reversed[::-1], axis=1) * action_mask
            advantage = advantage.detach()
            # KL divergence for logging
            per_token_kl = (
                torch.exp(reference_action_log_probs - action_log_probs)
                - (reference_action_log_probs - action_log_probs)
                - 1
            )
            kl = torch.sum(per_token_kl * action_mask, dim=-1) / torch.sum(action_mask, dim=-1)

            # Calculate Loss
            loss, skip_update, _ = self.policy_loss_fn(
                action_log_probs,
                old_action_log_probs,
                advantage,
                0,  # kl is already included in the advantage
                action_mask,
            )

            # Critic Loss
            # Hack: use the current value to approximate the old value, should be old value mathematically
            critic_loss = self.critic_loss_fn(
                value,
                value.detach(),
                advantage,
                action_mask=action_mask,
            )

            if not skip_update:
                self.booster.backward(loss, self.optimizer)
                self.critic_booster.backward(critic_loss, self.critic_optimizer)
            loss = all_reduce_mean(loss, self.plugin)
            critic_loss = all_reduce_mean(critic_loss, self.plugin)
            r_mean = all_reduce_mean(r.mean(), self.plugin)
            kl = all_reduce_mean(kl.mean(), self.plugin)
            advantage = all_reduce_mean(advantage.mean(), self.plugin)
            self.accum_loss.add_(loss.data)
            self.accum_critic_loss.add_(critic_loss.data)
            self.accum_advantage.add_(advantage.data)
            self.accum_reward.add_(r_mean.data)
            self.accum_kl.add_(kl.data)
            self.accum_count += 1
            if self.rank == 0:
                print(f"input_ids: {data['input_ids'].shape}, reward: {r_mean.item()}")
        if need_update:
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.critic_optimizer.step()
            self.critic_optimizer.zero_grad()
            loss_scalar = self.accum_loss.item()
            if self.rank == 0:
                print(
                    "Loss:",
                    self.accum_loss.item() / self.accum_count,
                    "Reward:",
                    self.accum_reward.item() / self.accum_count,
                    "KL:",
                    self.accum_kl.item() / self.accum_count,
                )
                if self.global_step % 3 == 0:
                    for i in range(min(3, data["input_ids"].shape[0])):
                        response_decoded_for_logging = self.tokenizer.decode(
                            data["input_ids"][i], skip_special_tokens=True
                        )
                        response_reward_for_logging = r[i]
                        print(f"###### Generation Sample {i} ######\nResponse: {response_decoded_for_logging}\nReward: {response_reward_for_logging}")
                self.wandb_run.log(
                    {
                        "train/loss": self.accum_loss.item() / self.accum_count,
                        "train/reward": self.accum_reward.item() / self.accum_count,
                        "train/kl": self.accum_kl.item() / self.accum_count,
                        "train/critic_loss": self.accum_critic_loss.item() / self.accum_count,
                        "train/advantage": self.accum_advantage.item() / self.accum_count,
                    }
                )
            self.accum_loss.zero_()
            self.accum_reward.zero_()
            self.accum_kl.zero_()
            self.accum_advantage.zero_()
            self.accum_critic_loss.zero_()
            self.accum_count = 0
            self.global_step += 1
            return loss_scalar

    def state_dict(self):
        self.policy_model._force_wait_all_gather()
        model = self.policy_model.unwrap()
        state_dict = model.state_dict()
        return state_dict
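To make the GAE loop in step() above easier to follow, here is the same recursion on toy numbers as a self-contained snippet (pure torch; the reward and value tensors are assumed values, and gamma and lam match the consumer's defaults of 1.0 and 0.95):

import torch

gamma, lam = 1.0, 0.95
reward = torch.tensor([[0.0, 0.0, 1.0]])  # toy per-token reward, verifier reward on the last token
value = torch.tensor([[0.2, 0.4, 0.6]])   # toy critic predictions
action_mask = torch.ones_like(reward)
num_action = reward.shape[1]

lastgaelam = 0
advantage_reversed = []
for t in reversed(range(num_action)):
    nextvalues = value[:, t + 1] if t < num_action - 1 else 0.0
    delta = reward[:, t] + gamma * nextvalues - value[:, t]  # TD error at step t
    lastgaelam = delta + gamma * lam * lastgaelam            # discounted running sum of TD errors
    advantage_reversed.append(lastgaelam)
advantage = torch.stack(advantage_reversed[::-1], dim=1) * action_mask
print(advantage)  # ≈ tensor([[0.7510, 0.5800, 0.4000]])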
@@ -154,6 +154,11 @@ class SimpleProducer(BaseProducer):

    @torch.no_grad()
    def rollout(self, input_ids, attention_mask, **kwargs):
        if self.backend_cls.__name__ == "TransformersInferenceBackend":
            # gt_answer is not a generate() argument, so pop it and attach it to the rollout output
            # for the reward model downstream.
            gt_answer = kwargs.pop("gt_answer")
            out = self.model.generate(input_ids, attention_mask, **kwargs)
            out["gt_answer"] = gt_answer.to(out["input_ids"].device)
            return out
        return self.model.generate(input_ids, attention_mask, **kwargs)

    def load_state_dict(self, state_dict):
@@ -19,8 +19,8 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
        return reward
    else:
        reward += 1.0
        if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
            reward = reward + 2.0
        # if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
        #     reward = reward + 2.0
    return reward
@@ -1,4 +1,4 @@
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional, Union

import torch
@@ -99,3 +99,30 @@ def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch
    mask_sum = mask.sum(dim=dim)
    mean = tensor / (mask_sum + 1e-8)
    return mean


def compute_reward_ppo(
    r: Union[torch.Tensor, float],
    kl_coef: float,
    log_probs: torch.Tensor,
    log_probs_base: torch.Tensor,
    action_mask: Optional[torch.Tensor] = None,
    reward_eps=5,
) -> torch.Tensor:
    """
    Args:
        log_probs: [batch_size, response_length]
        log_probs_base: [batch_size, response_length]
        action_mask: [batch_size, response_length]
        r: float
    Returns:
        reward: [batch_size, response_length]
    """
    log_ratio = log_probs - log_probs_base  # address numerical instability issue
    kl = -kl_coef * log_ratio * action_mask
    reward = kl
    r_clip = torch.clamp(r, -reward_eps, reward_eps)
    for i in range(action_mask.size(0)):
        assert action_mask[i].sum() > 0
        reward[i, : action_mask[i].sum()] += r_clip[i]
        reward[i, action_mask[i].sum() :] *= 0
    return reward, ((log_ratio * (log_ratio < 10)).exp() - 1 - log_ratio) * action_mask
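A quick usage sketch of compute_reward_ppo as defined above, with toy log-probabilities (assumed values): the clipped sequence-level reward r is added to every unmasked response token on top of the per-token KL penalty, padded positions are zeroed, and the second return value is a per-token KL estimate used for logging.

import torch
from coati.distributed.utils import compute_reward_ppo

log_probs = torch.tensor([[-1.0, -0.5, -2.0, -3.0]])       # policy log-probs (toy)
log_probs_base = torch.tensor([[-1.1, -0.4, -2.0, -3.0]])  # reference log-probs (toy)
action_mask = torch.tensor([[1, 1, 1, 0]])                 # last token is padding
r = torch.tensor([2.0])                                    # sequence-level verifier reward

reward, approx_kl = compute_reward_ppo(r, 0.05, log_probs, log_probs_base, action_mask=action_mask)
print(reward)  # ≈ tensor([[1.9950, 2.0050, 2.0000, 0.0000]])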
@@ -7,6 +7,7 @@ from coati.distributed.launch import launch_distributed
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
    parser.add_argument("-rm", "--reward_model", type=str, default="Qwen/Qwen2.5-7B")
    parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
    parser.add_argument("-t", "--num-trainers", type=int, default=2)
    parser.add_argument("-i", "--num-inferencer", type=int, default=2)
@@ -15,7 +16,7 @@ if __name__ == "__main__":
    parser.add_argument("-tbs", "--train-batch-size", type=int, default=16)
    parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
    parser.add_argument("-b", "--backend", type=str, default="transformers")
    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple, GPRO"])
    parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "PPO"])
    args = parser.parse_args()

    ray.init(address="local", namespace="ray-example")
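With "PPO" added to the algorithm choices above, PPO training can presumably be selected at launch time, e.g. (hypothetical invocation; the script name is assumed, the flags are the ones defined in this file):

python rl_example.py -m Qwen/Qwen2.5-7B -d data.jsonl -b transformers -a PPO -t 2 -i 2 -tbs 16 -tmbs 2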
@@ -27,26 +28,27 @@ if __name__ == "__main__":
        top_p=0.8,
    )

    if args.backend == "transformers":
    if args.backend == "transformers":
        inference_model_config.update(
            dict(
                attn_implementation="flash_attention_2",
                use_flash_attention_2=True,
                torch_dtype=torch.bfloat16,
            )
        )
        train_model_config.update(
            dict(
                attn_implementation="flash_attention_2",
                use_flash_attention_2=True,
                torch_dtype=torch.bfloat16,
                use_cache=False,
            )
        )
        generate_config.update(
            dict(
                max_length=512,
                max_length=768,
                do_sample=True,
                max_new_tokens=None,
                early_stopping=False,
                stop_strings=["</answer>"],
            )
        )
    elif args.backend == "vllm":
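One note on the stop_strings=["</answer>"] entry added above: Hugging Face transformers only honors stop_strings when a tokenizer is passed along to generate(), which is presumably why the TransformersInferenceBackend change earlier in this diff stores the tokenizer in generate_config.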
@@ -57,12 +59,13 @@ if __name__ == "__main__":
        )
        generate_config.update(
            dict(
                max_tokens=2048,
                max_tokens=512,
                ignore_eos=True,
                include_stop_str_in_output=True,
                stop=["</answer>"],
                temperature=0.7,
                temperature=0.5,
                top_p=0.95,
                n=1,
            )
        )
    else: