[NFC] polish applications/Chat/coati/models/generation.py code style (#4275)

pull/4338/head
RichardoLuo 2023-07-18 18:04:02 +08:00 committed by binmakeswell
parent dc1b6127f9
commit 709e121cd5
1 changed files with 6 additions and 7 deletions

View File

@ -5,7 +5,6 @@ import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
try: try:
from transformers.generation_logits_process import ( from transformers.generation_logits_process import (
LogitsProcessorList, LogitsProcessorList,
@ -148,12 +147,12 @@ def generate(model: nn.Module,
@torch.no_grad() @torch.no_grad()
def generate_with_actor(actor_model: nn.Module, def generate_with_actor(
input_ids: torch.Tensor, actor_model: nn.Module,
return_action_mask: bool = True, input_ids: torch.Tensor,
**kwargs return_action_mask: bool = True,
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], **kwargs
Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]: ) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
"""Generate token sequence with actor model. Refer to `generate` for more details. """Generate token sequence with actor model. Refer to `generate` for more details.
""" """
# generate sequences # generate sequences