[NFC] polish applications/Chat/coati/models/generation.py code style (#4275)

pull/4338/head
RichardoLuo 2023-07-18 18:04:02 +08:00 committed by binmakeswell
parent dc1b6127f9
commit 709e121cd5
1 changed files with 6 additions and 7 deletions

View File

@ -5,7 +5,6 @@ import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
try:
from transformers.generation_logits_process import (
LogitsProcessorList,
@ -148,12 +147,12 @@ def generate(model: nn.Module,
@torch.no_grad()
def generate_with_actor(actor_model: nn.Module,
input_ids: torch.Tensor,
return_action_mask: bool = True,
**kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor],
Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
def generate_with_actor(
actor_model: nn.Module,
input_ids: torch.Tensor,
return_action_mask: bool = True,
**kwargs
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
"""Generate token sequence with actor model. Refer to `generate` for more details.
"""
# generate sequences