mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish applications/Chat/coati/models/generation.py code style (#4275)
parent
dc1b6127f9
commit
709e121cd5
|
@ -5,7 +5,6 @@ import torch.distributed as dist
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
try:
|
||||
from transformers.generation_logits_process import (
|
||||
LogitsProcessorList,
|
||||
|
@ -148,12 +147,12 @@ def generate(model: nn.Module,
|
|||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_with_actor(actor_model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
return_action_mask: bool = True,
|
||||
**kwargs
|
||||
) -> Union[Tuple[torch.LongTensor, torch.LongTensor],
|
||||
Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
|
||||
def generate_with_actor(
|
||||
actor_model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
return_action_mask: bool = True,
|
||||
**kwargs
|
||||
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
|
||||
"""Generate token sequence with actor model. Refer to `generate` for more details.
|
||||
"""
|
||||
# generate sequences
|
||||
|
|
Loading…
Reference in New Issue