mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
feat: add batch inference API to llama stack inference (#1945)
# What does this PR do? This PR adds two methods to the Inference API: - `batch_completion` - `batch_chat_completion` The motivation is for evaluations targeting a local inference engine (like meta-reference or vllm) where batch APIs provide for a substantial amount of acceleration. Why did I not add this to `Api.batch_inference` though? That just resulted in a _lot_ more book-keeping given the structure of Llama Stack. Had I done that, I would have needed to create a notion of a "batch model" resource, setup routing based on that, etc. This does not sound ideal. So what's the future of the batch inference API? I am not sure. Maybe we can keep it for true _asynchronous_ execution. So you can submit requests, and it can return a Job instance, etc. ## Test Plan Run meta-reference-gpu using: ```bash export INFERENCE_MODEL=meta-llama/Llama-4-Scout-17B-16E-Instruct export INFERENCE_CHECKPOINT_DIR=../checkpoints/Llama-4-Scout-17B-16E-Instruct-20250331210000 export MODEL_PARALLEL_SIZE=4 export MAX_BATCH_SIZE=32 export MAX_SEQ_LEN=6144 LLAMA_MODELS_DEBUG=1 llama stack run meta-reference-gpu ``` Then run the batch inference test case.
This commit is contained in:
parent
854c2ad264
commit
f34f22f8c7
23 changed files with 698 additions and 389 deletions
76
tests/integration/inference/test_batch_inference.py
Normal file
76
tests/integration/inference/test_batch_inference.py
Normal file
|
@ -0,0 +1,76 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from ..test_cases.test_case import TestCase
|
||||
|
||||
|
||||
def skip_if_provider_doesnt_support_batch_inference(client_with_models, model_id):
|
||||
models = {m.identifier: m for m in client_with_models.models.list()}
|
||||
models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
|
||||
provider_id = models[model_id].provider_id
|
||||
providers = {p.provider_id: p for p in client_with_models.providers.list()}
|
||||
provider = providers[provider_id]
|
||||
if provider.provider_type not in ("inline::meta-reference",):
|
||||
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support batch inference")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
[
|
||||
"inference:completion:batch_completion",
|
||||
],
|
||||
)
|
||||
def test_batch_completion_non_streaming(client_with_models, text_model_id, test_case):
|
||||
skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id)
|
||||
tc = TestCase(test_case)
|
||||
|
||||
content_batch = tc["contents"]
|
||||
response = client_with_models.inference.batch_completion(
|
||||
content_batch=content_batch,
|
||||
model_id=text_model_id,
|
||||
sampling_params={
|
||||
"max_tokens": 50,
|
||||
},
|
||||
)
|
||||
assert len(response.batch) == len(content_batch)
|
||||
for i, r in enumerate(response.batch):
|
||||
print(f"response {i}: {r.content}")
|
||||
assert len(r.content) > 10
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_case",
|
||||
[
|
||||
"inference:chat_completion:batch_completion",
|
||||
],
|
||||
)
|
||||
def test_batch_chat_completion_non_streaming(client_with_models, text_model_id, test_case):
|
||||
skip_if_provider_doesnt_support_batch_inference(client_with_models, text_model_id)
|
||||
tc = TestCase(test_case)
|
||||
qa_pairs = tc["qa_pairs"]
|
||||
|
||||
message_batch = [
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": qa["question"],
|
||||
}
|
||||
]
|
||||
for qa in qa_pairs
|
||||
]
|
||||
|
||||
response = client_with_models.inference.batch_chat_completion(
|
||||
messages_batch=message_batch,
|
||||
model_id=text_model_id,
|
||||
)
|
||||
assert len(response.batch) == len(qa_pairs)
|
||||
for i, r in enumerate(response.batch):
|
||||
print(f"response {i}: {r.completion_message.content}")
|
||||
assert len(r.completion_message.content) > 0
|
||||
assert qa_pairs[i]["answer"].lower() in r.completion_message.content.lower()
|
|
@ -537,5 +537,31 @@
|
|||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"batch_completion": {
|
||||
"data": {
|
||||
"qa_pairs": [
|
||||
{
|
||||
"question": "What is the capital of France?",
|
||||
"answer": "Paris"
|
||||
},
|
||||
{
|
||||
"question": "Who wrote the book '1984'?",
|
||||
"answer": "George Orwell"
|
||||
},
|
||||
{
|
||||
"question": "Which planet has rings around it with a name starting with letter S?",
|
||||
"answer": "Saturn"
|
||||
},
|
||||
{
|
||||
"question": "When did the first moon landing happen?",
|
||||
"answer": "1969"
|
||||
},
|
||||
{
|
||||
"question": "What word says 'hello' in Spanish?",
|
||||
"answer": "Hola"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,5 +44,18 @@
|
|||
"year_retired": "2003"
|
||||
}
|
||||
}
|
||||
},
|
||||
"batch_completion": {
|
||||
"data": {
|
||||
"contents": [
|
||||
"Micheael Jordan is born in ",
|
||||
"Roses are red, violets are ",
|
||||
"If you had a million dollars, what would you do with it? ",
|
||||
"All you need is ",
|
||||
"The capital of France is ",
|
||||
"It is a good day to ",
|
||||
"The answer to the universe is "
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue