diff --git a/applications/Chat/coati/models/generation.py b/applications/Chat/coati/models/generation.py
index 0156e2284..d96ad78a8 100644
--- a/applications/Chat/coati/models/generation.py
+++ b/applications/Chat/coati/models/generation.py
@@ -5,7 +5,6 @@
 import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
-
 try:
     from transformers.generation_logits_process import (
         LogitsProcessorList,
@@ -148,12 +147,12 @@
 @torch.no_grad()
-def generate_with_actor(actor_model: nn.Module,
-                        input_ids: torch.Tensor,
-                        return_action_mask: bool = True,
-                        **kwargs
-                        ) -> Union[Tuple[torch.LongTensor, torch.LongTensor],
-                                   Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
+def generate_with_actor(
+    actor_model: nn.Module,
+    input_ids: torch.Tensor,
+    return_action_mask: bool = True,
+    **kwargs
+) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
     """Generate token sequence with actor model.
 
     Refer to `generate` for more details.
     """
     # generate sequences