feat: add batch inference API to llama stack inference (#1945)

# What does this PR do?

This PR adds two methods to the Inference API:
- `batch_completion`
- `batch_chat_completion`

The motivation is evaluations that target a local inference engine (like
meta-reference or vllm), where batch APIs provide a substantial speedup.
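
For a sense of the surface area, here is a minimal sketch of calling one of the new methods against a running stack. The client-side attribute path, the `content_batch` parameter name, and the response shape are assumptions modeled on the existing single-request API, not verified signatures from this PR:

```python
# Sketch only: calling the new batch completion method from a client.
# `batch_completion` is the method added by this PR, but the parameter name
# `content_batch` and the `response.batch` shape are assumptions.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

response = client.inference.batch_completion(
    model_id="meta-llama/Llama-4-Scout-17B-16E-Instruct",
    content_batch=[
        "The capital of France is",
        "A one-sentence definition of batch inference is",
    ],
)

# Assumed shape: one completion per input prompt, in input order.
for completion in response.batch:
    print(completion.content)
```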

Why did I not add this to `Api.batch_inference`? That just resulted in a
_lot_ more book-keeping given the structure of Llama Stack. Had I done that,
I would have needed to create a notion of a "batch model" resource, set up
routing based on it, etc. This does not sound ideal.

So what's the future of the batch inference API? I am not sure. Maybe we
can keep it for true _asynchronous_ execution: you would submit requests and
get back a Job instance, etc.

## Test Plan

Run meta-reference-gpu using:
```bash
export INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct
export INFERENCE_CHECKPOINT_DIR=../checkpoints/Llama-4-Scout-17B-16E-Instruct-20250331210000
export MODEL_PARALLEL_SIZE=4
export MAX_BATCH_SIZE=32
export MAX_SEQ_LEN=6144

LLAMA_MODELS_DEBUG=1 llama stack run meta-reference-gpu
```

Then run the batch inference test case.
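
A rough sketch of what that test case might look like (the fixture names, the `messages_batch` parameter, and the response attributes are assumptions, not the actual test shipped with this PR):

```python
# Hypothetical sketch of the batch inference test case; fixture names and
# response attributes are assumptions modeled on the PR description.
def test_batch_chat_completion(client_with_models, text_model_id):
    messages_batch = [
        [{"role": "user", "content": "What is 2 + 2?"}],
        [{"role": "user", "content": "Name the largest planet in our solar system."}],
    ]
    response = client_with_models.inference.batch_chat_completion(
        model_id=text_model_id,
        messages_batch=messages_batch,
    )
    # Expect one completion per input conversation, in order.
    assert len(response.batch) == len(messages_batch)
    for completion in response.batch:
        assert completion.completion_message.content
```

The corresponding provider-side change to the model-parallel runner is excerpted in the diff below.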
```diff
@@ -6,7 +6,7 @@
 from copy import deepcopy
 from functools import partial
-from typing import Any, Callable, Generator
+from typing import Any, Callable, Generator, List
 
 from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
 from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
@@ -23,13 +23,13 @@ class ModelRunner:
         self.llama = llama
 
     # the `task` object is the same that is sent to `ModelParallelProcessGroup.run_inference()`
-    def __call__(self, req: Any):
-        if isinstance(req, ChatCompletionRequestWithRawContent):
-            return self.llama.chat_completion(req)
-        elif isinstance(req, CompletionRequestWithRawContent):
-            return self.llama.completion(req)
+    def __call__(self, task: Any):
+        if task[0] == "chat_completion":
+            return self.llama.chat_completion(task[1])
+        elif task[0] == "completion":
+            return self.llama.completion(task[1])
         else:
-            raise ValueError(f"Unexpected task type {type(req)}")
+            raise ValueError(f"Unexpected task type {task[0]}")
 
 
 def init_model_cb(
@@ -82,16 +82,16 @@ class LlamaModelParallelGenerator:
     def completion(
         self,
-        request: CompletionRequestWithRawContent,
+        request_batch: List[CompletionRequestWithRawContent],
     ) -> Generator:
-        req_obj = deepcopy(request)
-        gen = self.group.run_inference(req_obj)
+        req_obj = deepcopy(request_batch)
+        gen = self.group.run_inference(("completion", req_obj))
         yield from gen
 
     def chat_completion(
         self,
-        request: ChatCompletionRequestWithRawContent,
+        request_batch: List[ChatCompletionRequestWithRawContent],
     ) -> Generator:
-        req_obj = deepcopy(request)
-        gen = self.group.run_inference(req_obj)
+        req_obj = deepcopy(request_batch)
+        gen = self.group.run_inference(("chat_completion", req_obj))
         yield from gen
```
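
The net effect of the diff: instead of dispatching on the request's type, the model-parallel worker now receives a `(task_name, request_batch)` tuple, so one callable serves both batched entry points. A minimal, self-contained illustration of that dispatch (the `DummyLlama` class is a stand-in for the real generator; only the tuple shape comes from the diff above):

```python
# Illustration of the tuple-tagged task dispatch performed by ModelRunner.__call__.
from typing import Any, List


class DummyLlama:
    """Stand-in for the real Llama generator; returns one result per request."""

    def chat_completion(self, request_batch: List[Any]) -> List[str]:
        return [f"chat result {i}" for i, _ in enumerate(request_batch)]

    def completion(self, request_batch: List[Any]) -> List[str]:
        return [f"completion result {i}" for i, _ in enumerate(request_batch)]


class ModelRunner:
    def __init__(self, llama: DummyLlama) -> None:
        self.llama = llama

    def __call__(self, task: Any):
        # task is ("chat_completion" | "completion", [request, ...])
        if task[0] == "chat_completion":
            return self.llama.chat_completion(task[1])
        elif task[0] == "completion":
            return self.llama.completion(task[1])
        else:
            raise ValueError(f"Unexpected task type {task[0]}")


runner = ModelRunner(DummyLlama())
print(runner(("chat_completion", ["req-a", "req-b"])))  # two results, in input order
```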