Hello Folks!
I have a slightly off-topic and random question. I know it is a basic inference question, but I would like to know the best way to run inference in batches.
The overall goal
-> I am using the t5-small model for a summarization task.
-> I have around 10 to 15 different paragraphs to summarize in a single call.

The code I am using right now
It is a generic loop, and I expect there is room for optimization here:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

points = [
    "summarize: ABC...",
    "summarize: CBA...",
    "summarize: ERG...",
    "summarize: RAG...",
]

summaries = []
for point in points:
    # Encode one paragraph at a time and generate its summary
    input_ids = tokenizer.encode(point, return_tensors="pt", max_length=512, truncation=True)
    output_ids = model.generate(input_ids, max_length=256, temperature=0.35, do_sample=True)
    summaries.append(tokenizer.decode(output_ids[0], skip_special_tokens=True))
This works, but it takes time, as expected.

I have tried this:
# Passing the inputs to the tokenizer as a batch
ids = tokenizer(points, return_tensors="pt", max_length=512, padding="longest", truncation=True)
response = model.generate(**ids, max_length=256, temperature=0.35, do_sample=True)
summaries = tokenizer.batch_decode(response, skip_special_tokens=True)
But I am worried that the model might connect the paragraphs internally and leak information between them. I am not sure, but this way is significantly faster than the loop.
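For reference, here is a rough sanity check I was thinking of (just a sketch on my side, assuming deterministic greedy decoding with do_sample=False): if the batched outputs match the one-at-a-time outputs, the padded batch should not be mixing content between inputs, apart from possible tiny numerical differences caused by padding.

import torch

# Sketch: compare per-item vs. batched generation with deterministic decoding.
# If the decoded texts match, the inputs in the batch are not talking to each other.
with torch.no_grad():
    single = []
    for point in points:
        enc = tokenizer(point, return_tensors="pt", max_length=512, truncation=True)
        out = model.generate(**enc, max_length=256, do_sample=False)
        single.append(tokenizer.decode(out[0], skip_special_tokens=True))

    batch = tokenizer(points, return_tensors="pt", max_length=512, padding="longest", truncation=True)
    out = model.generate(**batch, max_length=256, do_sample=False)
    batched = tokenizer.batch_decode(out, skip_special_tokens=True)

print(single == batched)  # expected to be True (or nearly so) if the inputs are independent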
My Ask
Am I doing it right? How can I perform batch inference that is fast while making sure the inputs are NOT talking to each other?
Is there any other way to increase the speed? (Or would I just need to use threading?)
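In case it helps frame the question, this is the direction I was considering for speed (only a sketch; the batch_size value and the device handling are my own assumptions, not something I have verified as best practice): process the paragraphs in fixed-size chunks, wrap generation in torch.no_grad(), and move the model and inputs to the GPU if one is available.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

def summarize_in_chunks(texts, batch_size=8):
    # Sketch: chunked batch inference; batch_size is just a tuning knob.
    results = []
    with torch.no_grad():  # no gradients are needed for inference
        for i in range(0, len(texts), batch_size):
            chunk = texts[i:i + batch_size]
            enc = tokenizer(chunk, return_tensors="pt", max_length=512,
                            padding="longest", truncation=True).to(device)
            out = model.generate(**enc, max_length=256, temperature=0.35, do_sample=True)
            results.extend(tokenizer.batch_decode(out, skip_special_tokens=True))
    return results

print(summarize_in_chunks(points))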
Thanks!