Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-27 18:50:41 +00:00)
# What does this PR do?

This stubs in some OpenAI server-side compatibility with three new endpoints:

* /v1/openai/v1/models
* /v1/openai/v1/completions
* /v1/openai/v1/chat/completions

This gives common inference apps using OpenAI clients the ability to talk to Llama Stack using an endpoint like http://localhost:8321/v1/openai/v1. The two "v1" instances in there aren't awesome, but the thinking is that Llama Stack's API is v1 and our OpenAI compatibility layer is compatible with OpenAI V1. And, since some OpenAI clients implicitly assume the URL ends with "v1", this gives maximum compatibility.

The OpenAI models endpoint is implemented in the routing layer and simply returns all the models Llama Stack knows about.

The following providers should be working with the new OpenAI completions and chat completions APIs:

* remote::anthropic (untested)
* remote::cerebras-openai-compat (untested)
* remote::fireworks (tested)
* remote::fireworks-openai-compat (untested)
* remote::gemini (untested)
* remote::groq-openai-compat (untested)
* remote::nvidia (tested)
* remote::ollama (tested)
* remote::openai (untested)
* remote::passthrough (untested)
* remote::sambanova-openai-compat (untested)
* remote::together (tested)
* remote::together-openai-compat (untested)
* remote::vllm (tested)

The goal is to support this for every inference provider, proxying directly to the provider's OpenAI endpoint for OpenAI-compatible providers. For providers that don't have an OpenAI-compatible API, we'll add a mixin to translate incoming OpenAI requests into Llama Stack inference requests and translate the Llama Stack inference responses back into OpenAI responses.

This is related to #1817 but is a bit larger in scope than just chat completions, as I have real use cases that need the older completions API as well.

## Test Plan

### vLLM

```
VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" llama stack build --template remote-vllm --image-type venv --run

LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "meta-llama/Llama-3.2-3B-Instruct"
```

### ollama

```
INFERENCE_MODEL="llama3.2:3b-instruct-q8_0" llama stack build --template ollama --image-type venv --run

LLAMA_STACK_CONFIG=http://localhost:8321 INFERENCE_MODEL="llama3.2:3b-instruct-q8_0" python -m pytest -v tests/integration/inference/test_openai_completion.py --text-model "llama3.2:3b-instruct-q8_0"
```

## Documentation

Run a Llama Stack distribution that uses one of the providers mentioned in the list above. Then, use your favorite OpenAI client to send completion or chat completion requests with the base_url set to http://localhost:8321/v1/openai/v1. Replace "localhost:8321" with the host and port of your Llama Stack server, if different. A minimal client sketch follows below.

---------

Signed-off-by: Ben Browning <bbrownin@redhat.com>
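For reference, here is a minimal Python sketch of the client usage described in the Documentation section above, assuming a Llama Stack server running on localhost:8321. The model ID ("meta-llama/Llama-3.2-3B-Instruct") and the api_key value are illustrative placeholders: any model registered with your distribution works, and the integration tests below simply pass an arbitrary key.

```python
# Minimal sketch: talk to Llama Stack's OpenAI compatibility layer with the
# standard OpenAI Python client. The base_url comes from the Documentation
# section above; the model ID and api_key are illustrative placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# /v1/openai/v1/models: list every model Llama Stack knows about
for model in client.models.list():
    print(model.id)

# /v1/openai/v1/completions: the older text-completions endpoint
completion = client.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    prompt="Respond to this question and explain your answer. What is 2 + 2?",
)
print(completion.choices[0].text)

# /v1/openai/v1/chat/completions: the chat-completions endpoint
chat = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    messages=[{"role": "user", "content": "Name the planets in our solar system."}],
)
print(chat.choices[0].message.content)
```

Because only the base_url differs from talking to OpenAI directly, any OpenAI-compatible client or app should work the same way.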
216 lines · 7.2 KiB · Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import pytest
from openai import OpenAI

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

from ..test_cases.test_case import TestCase


def provider_from_model(client_with_models, model_id):
    models = {m.identifier: m for m in client_with_models.models.list()}
    models.update({m.provider_resource_id: m for m in client_with_models.models.list()})
    provider_id = models[model_id].provider_id
    providers = {p.provider_id: p for p in client_with_models.providers.list()}
    return providers[provider_id]


def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id):
    if isinstance(client_with_models, LlamaStackAsLibraryClient):
        pytest.skip("OpenAI completions are not supported when testing with library client yet.")

    provider = provider_from_model(client_with_models, model_id)
    if provider.provider_type in (
        "inline::meta-reference",
        "inline::sentence-transformers",
        "inline::vllm",
        "remote::bedrock",
        "remote::cerebras",
        "remote::databricks",
        # Technically Nvidia does support OpenAI completions, but none of their hosted models
        # support both completions and chat completions endpoint and all the Llama models are
        # just chat completions
        "remote::nvidia",
        "remote::runpod",
        "remote::sambanova",
        "remote::tgi",
    ):
        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")


def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
    if isinstance(client_with_models, LlamaStackAsLibraryClient):
        pytest.skip("OpenAI chat completions are not supported when testing with library client yet.")

    provider = provider_from_model(client_with_models, model_id)
    if provider.provider_type in (
        "inline::meta-reference",
        "inline::sentence-transformers",
        "inline::vllm",
        "remote::bedrock",
        "remote::cerebras",
        "remote::databricks",
        "remote::runpod",
        "remote::sambanova",
        "remote::tgi",
    ):
        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.")


def skip_if_provider_isnt_vllm(client_with_models, model_id):
    provider = provider_from_model(client_with_models, model_id)
    if provider.provider_type != "remote::vllm":
        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support vllm extra_body parameters.")


@pytest.fixture
def openai_client(client_with_models):
    base_url = f"{client_with_models.base_url}/v1/openai/v1"
    return OpenAI(base_url=base_url, api_key="bar")


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:completion:sanity",
    ],
)
def test_openai_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
    skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
    tc = TestCase(test_case)

    # ollama needs more verbose prompting for some reason here...
    prompt = "Respond to this question and explain your answer. " + tc["content"]
    response = openai_client.completions.create(
        model=text_model_id,
        prompt=prompt,
        stream=False,
    )
    assert len(response.choices) > 0
    choice = response.choices[0]
    assert len(choice.text) > 10


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:completion:sanity",
    ],
)
def test_openai_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
    skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
    tc = TestCase(test_case)

    # ollama needs more verbose prompting for some reason here...
    prompt = "Respond to this question and explain your answer. " + tc["content"]
    response = openai_client.completions.create(
        model=text_model_id,
        prompt=prompt,
        stream=True,
        max_tokens=50,
    )
    streamed_content = [chunk.choices[0].text for chunk in response]
    content_str = "".join(streamed_content).lower().strip()
    assert len(content_str) > 10


@pytest.mark.parametrize(
    "prompt_logprobs",
    [
        1,
        0,
    ],
)
def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id, prompt_logprobs):
    skip_if_provider_isnt_vllm(client_with_models, text_model_id)

    prompt = "Hello, world!"
    response = openai_client.completions.create(
        model=text_model_id,
        prompt=prompt,
        stream=False,
        extra_body={
            "prompt_logprobs": prompt_logprobs,
        },
    )
    assert len(response.choices) > 0
    choice = response.choices[0]
    assert len(choice.prompt_logprobs) > 0


def test_openai_completion_guided_choice(openai_client, client_with_models, text_model_id):
    skip_if_provider_isnt_vllm(client_with_models, text_model_id)

    prompt = "I am feeling really sad today."
    response = openai_client.completions.create(
        model=text_model_id,
        prompt=prompt,
        stream=False,
        extra_body={
            "guided_choice": ["joy", "sadness"],
        },
    )
    assert len(response.choices) > 0
    choice = response.choices[0]
    assert choice.text in ["joy", "sadness"]


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:chat_completion:non_streaming_01",
        "inference:chat_completion:non_streaming_02",
    ],
)
def test_openai_chat_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
    skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
    tc = TestCase(test_case)
    question = tc["question"]
    expected = tc["expected"]

    response = openai_client.chat.completions.create(
        model=text_model_id,
        messages=[
            {
                "role": "user",
                "content": question,
            }
        ],
        stream=False,
    )
    message_content = response.choices[0].message.content.lower().strip()
    assert len(message_content) > 0
    assert expected.lower() in message_content


@pytest.mark.parametrize(
    "test_case",
    [
        "inference:chat_completion:streaming_01",
        "inference:chat_completion:streaming_02",
    ],
)
def test_openai_chat_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
    skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
    tc = TestCase(test_case)
    question = tc["question"]
    expected = tc["expected"]

    response = openai_client.chat.completions.create(
        model=text_model_id,
        messages=[{"role": "user", "content": question}],
        stream=True,
        timeout=120,  # Increase timeout to 2 minutes for large conversation history
    )
    streamed_content = []
    for chunk in response:
        if chunk.choices[0].delta.content:
            streamed_content.append(chunk.choices[0].delta.content.lower().strip())
    assert len(streamed_content) > 0
    assert expected.lower() in "".join(streamed_content)