diff --git a/ptuning/main.py b/ptuning/main.py index b027e9e..6328eac 100644 --- a/ptuning/main.py +++ b/ptuning/main.py @@ -166,8 +166,8 @@ def main(): else: prompt = "" history = examples[history_column][i] - for i, (old_query, response) in enumerate(history): - prompt += "[Round {}]\n问:{}\n答:{}\n".format(i, old_query, response) + for turn_idx, (old_query, response) in enumerate(history): + prompt += "[Round {}]\n问:{}\n答:{}\n".format(turn_idx, old_query, response) prompt += "[Round {}]\n问:{}\n答:".format(len(history), query) inputs.append(prompt) targets.append(examples[response_column][i])