mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish applications/Chat/coati/models/generation.py code style (#4275)
parent
dc1b6127f9
commit
709e121cd5
|
@ -5,7 +5,6 @@ import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from transformers.generation_logits_process import (
|
from transformers.generation_logits_process import (
|
||||||
LogitsProcessorList,
|
LogitsProcessorList,
|
||||||
|
@ -148,12 +147,12 @@ def generate(model: nn.Module,
|
||||||
|
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def generate_with_actor(actor_model: nn.Module,
|
def generate_with_actor(
|
||||||
input_ids: torch.Tensor,
|
actor_model: nn.Module,
|
||||||
return_action_mask: bool = True,
|
input_ids: torch.Tensor,
|
||||||
**kwargs
|
return_action_mask: bool = True,
|
||||||
) -> Union[Tuple[torch.LongTensor, torch.LongTensor],
|
**kwargs
|
||||||
Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
|
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
|
||||||
"""Generate token sequence with actor model. Refer to `generate` for more details.
|
"""Generate token sequence with actor model. Refer to `generate` for more details.
|
||||||
"""
|
"""
|
||||||
# generate sequences
|
# generate sequences
|
||||||
|
|
Loading…
Reference in New Issue