mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish applications/Chat/coati/models/base/actor.py code style (#4248)
parent
915ed8bed1
commit
77c469e1ba
|
@ -21,16 +21,13 @@ class Actor(LoRAModule):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.convert_to_lora()
|
self.convert_to_lora()
|
||||||
|
|
||||||
def forward(self,
|
def forward(
|
||||||
input_ids: torch.LongTensor,
|
self,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
input_ids: torch.LongTensor,
|
||||||
**model_kwargs, # HACK: `generate` method may pass more kwargs
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
**model_kwargs, # HACK: `generate` method may pass more kwargs
|
||||||
|
) -> torch.Tensor:
|
||||||
"""Returns model output.
|
"""Returns model output.
|
||||||
"""
|
"""
|
||||||
output = self.model(
|
output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
|
||||||
input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
**model_kwargs
|
|
||||||
)
|
|
||||||
return output
|
return output
|
||||||
|
|
Loading…
Reference in New Issue