# mirror of https://github.com/hpcaitech/ColossalAI
import torch
from typing import List, Deque, Tuple, Hashable, Any

from energonai import BatchManager, SubmitEntry, TaskEntry

class BatchManagerForGeneration(BatchManager):
    """Groups generation requests with identical sampling settings into left-padded batches."""

    def __init__(self, max_batch_size: int = 1, pad_token_id: int = 0) -> None:
        super().__init__()
        self.max_batch_size = max_batch_size
        self.pad_token_id = pad_token_id
    def _left_padding(self, batch_inputs):
        # Left-pad every sequence to the length of the longest prompt so that
        # generated tokens are appended at a common right-hand boundary.
        max_len = max(len(inputs['input_ids']) for inputs in batch_inputs)
        outputs = {'input_ids': [], 'attention_mask': []}
        for inputs in batch_inputs:
            input_ids, attention_mask = inputs['input_ids'], inputs['attention_mask']
            padding_len = max_len - len(input_ids)
            input_ids = [self.pad_token_id] * padding_len + input_ids
            attention_mask = [0] * padding_len + attention_mask
            outputs['input_ids'].append(input_ids)
            outputs['attention_mask'].append(attention_mask)
        for k in outputs:
            outputs[k] = torch.tensor(outputs[k])
        return outputs, max_len
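    # Worked example with hypothetical values: prompts [5, 6, 7] and [8, 9]
    # give max_len = 3, so the second becomes input_ids=[pad_token_id, 8, 9]
    # with attention_mask=[0, 1, 1], and both stacked tensors have shape (2, 3).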

    @staticmethod
    def _make_batch_key(entry: SubmitEntry) -> tuple:
        # Only requests with identical sampling settings may share a batch.
        data = entry.data
        return (data['top_k'], data['top_p'], data['temperature'])

    def make_batch(self, q: Deque[SubmitEntry]) -> Tuple[TaskEntry, dict]:
        entry = q.popleft()
        uids = [entry.uid]
        batch = [entry.data]
        # Greedily take more requests from the queue until the batch is full,
        # the sampling settings change, or a request asks for more new tokens
        # than the first entry (the batch runs on the first entry's budget).
        while len(batch) < self.max_batch_size:
            if len(q) == 0:
                break
            if self._make_batch_key(entry) != self._make_batch_key(q[0]):
                break
            if q[0].data['max_tokens'] > entry.data['max_tokens']:
                break
            e = q.popleft()
            batch.append(e.data)
            uids.append(e.uid)
        inputs, max_len = self._left_padding(batch)
        # Record, per request, how long its output should be: the common
        # padded prompt length plus that request's own max_tokens.
        trunc_lens = []
        for data in batch:
            trunc_lens.append(max_len + data['max_tokens'])
        inputs['top_k'] = entry.data['top_k']
        inputs['top_p'] = entry.data['top_p']
        inputs['temperature'] = entry.data['temperature']
        inputs['max_tokens'] = max_len + entry.data['max_tokens']
        return TaskEntry(tuple(uids), inputs), {'trunc_lens': trunc_lens}

    def split_batch(self, task_entry: TaskEntry, trunc_lens: List[int] = []) -> List[Tuple[Hashable, Any]]:
        # Route each output row back to its submitter, truncated to the
        # length recorded for that request in make_batch.
        retval = []
        for uid, output, trunc_len in zip(task_entry.uids, task_entry.batch, trunc_lens):
            retval.append((uid, output[:trunc_len]))
        return retval
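

# A minimal usage sketch, not part of the original file. It assumes that
# SubmitEntry and TaskEntry can be constructed as SubmitEntry(uid, data) and
# TaskEntry(uids, batch), matching the attribute access in the class above.
if __name__ == '__main__':
    from collections import deque

    mgr = BatchManagerForGeneration(max_batch_size=4, pad_token_id=0)
    common = {'top_k': 50, 'top_p': 0.9, 'temperature': 1.0, 'max_tokens': 4}
    q = deque([
        SubmitEntry(0, {'input_ids': [5, 6, 7], 'attention_mask': [1, 1, 1], **common}),
        SubmitEntry(1, {'input_ids': [8, 9], 'attention_mask': [1, 1], **common}),
    ])

    task, info = mgr.make_batch(q)
    print(task.batch['input_ids'])   # tensor([[5, 6, 7], [0, 8, 9]])
    print(task.batch['max_tokens'])  # 7 == max_len (3) + max_tokens (4)

    # Stand-in for model output: one row of max_tokens token ids per request.
    fake_out = torch.arange(2 * task.batch['max_tokens']).reshape(2, -1)
    for uid, out in mgr.split_batch(TaskEntry(task.uids, fake_out), **info):
        print(uid, out.shape)  # each row truncated to that request's trunc_len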