Mirror of https://github.com/meta-llama/llama-stack.git
Nvidia provider support for OpenAI API endpoints
This wires up the openai_completion and openai_chat_completion API methods for the remote Nvidia inference provider, and adds the provider to the chat completions part of the OpenAI test suite. The hosted Nvidia service doesn't currently offer any Llama models with working completions and chat completions endpoints, so for now the test suite only enables the Nvidia provider for chat completions.

Signed-off-by: Ben Browning <bbrownin@redhat.com>
commit a5827f7cb3 (parent 8f5cd49159)
2 changed files with 143 additions and 15 deletions
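The provider-side half of the change (the second changed file) isn't shown in the hunks below; per the commit message, the remote Nvidia provider gains `openai_completion` and `openai_chat_completion` methods that forward requests to NVIDIA's OpenAI-compatible service. A minimal sketch of that idea, using the stock `openai` Python client rather than llama-stack's actual provider classes (the class name, endpoint URL, and method signatures here are illustrative assumptions, not the real implementation):

```python
# Illustrative sketch only -- NOT the actual llama-stack provider code.
# Assumes an OpenAI-compatible NVIDIA endpoint (e.g. https://integrate.api.nvidia.com/v1);
# the real provider plugs into llama-stack's inference provider interfaces instead.
from openai import AsyncOpenAI


class NvidiaOpenAISketch:
    def __init__(self, base_url: str, api_key: str) -> None:
        # NVIDIA's hosted service speaks the OpenAI wire format, so the stock client works.
        self._client = AsyncOpenAI(base_url=base_url, api_key=api_key)

    async def openai_chat_completion(self, model: str, messages: list[dict], **kwargs):
        # Forward an OpenAI-style chat completion request unchanged.
        return await self._client.chat.completions.create(model=model, messages=messages, **kwargs)

    async def openai_completion(self, model: str, prompt: str, **kwargs):
        # Forward a legacy completions request; per the commit message, the hosted
        # Llama models only serve chat completions, so this path goes unexercised upstream.
        return await self._client.completions.create(model=model, prompt=prompt, **kwargs)
```

The test-suite side of the change is visible in the hunks below.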
```diff
@@ -33,6 +33,9 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
         "remote::bedrock",
         "remote::cerebras",
         "remote::databricks",
+        # Technically Nvidia does support OpenAI completions, but none of their hosted models
+        # support both completions and chat completions endpoint and all the Llama models are
+        # just chat completions
+        "remote::nvidia",
         "remote::runpod",
         "remote::sambanova",
@@ -41,6 +44,25 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")
 
 
+def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
+    if isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("OpenAI chat completions are not supported when testing with library client yet.")
+
+    provider = provider_from_model(client_with_models, model_id)
+    if provider.provider_type in (
+        "inline::meta-reference",
+        "inline::sentence-transformers",
+        "inline::vllm",
+        "remote::bedrock",
+        "remote::cerebras",
+        "remote::databricks",
+        "remote::runpod",
+        "remote::sambanova",
+        "remote::tgi",
+    ):
+        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.")
+
+
 def skip_if_provider_isnt_vllm(client_with_models, model_id):
     provider = provider_from_model(client_with_models, model_id)
     if provider.provider_type != "remote::vllm":
```
```diff
@@ -48,8 +70,7 @@ def skip_if_provider_isnt_vllm(client_with_models, model_id):
 
 
 @pytest.fixture
-def openai_client(client_with_models, text_model_id):
-    skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
+def openai_client(client_with_models):
     base_url = f"{client_with_models.base_url}/v1/openai/v1"
     return OpenAI(base_url=base_url, api_key="bar")
 
 
```
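With the `openai_client` fixture no longer tied to `text_model_id` or to the completions-only skip, each test applies whichever skip helper matches the endpoint it exercises. As a hedged illustration of how the fixture is then consumed (placeholder prompt and assertion, standard `openai` client calls; not a test copied from the repository):

```python
# Illustrative use of the openai_client fixture above (not part of the diff).
def test_chat_completion_example(openai_client, client_with_models, text_model_id):
    skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
    response = openai_client.chat.completions.create(
        model=text_model_id,
        messages=[{"role": "user", "content": "Say hello."}],
        stream=False,
    )
    # Standard OpenAI response shape: the text lives in choices[0].message.content.
    assert response.choices[0].message.content
```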
```diff
@@ -60,7 +81,8 @@ def openai_client(client_with_models, text_model_id):
         "inference:completion:sanity",
     ],
 )
-def test_openai_completion_non_streaming(openai_client, text_model_id, test_case):
+def test_openai_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     # ollama needs more verbose prompting for some reason here...
@@ -81,7 +103,8 @@ def test_openai_completion_non_streaming(openai_client, text_model_id, test_case
         "inference:completion:sanity",
     ],
 )
-def test_openai_completion_streaming(openai_client, text_model_id, test_case):
+def test_openai_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)
 
     # ollama needs more verbose prompting for some reason here...
@@ -145,7 +168,8 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text
         "inference:chat_completion:non_streaming_02",
     ],
 )
-def test_openai_chat_completion_non_streaming(openai_client, text_model_id, test_case):
+def test_openai_chat_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)
     question = tc["question"]
     expected = tc["expected"]
@@ -172,7 +196,8 @@ def test_openai_chat_completion_non_streaming(openai_client, text_model_id, test
         "inference:chat_completion:streaming_02",
     ],
 )
-def test_openai_chat_completion_streaming(openai_client, text_model_id, test_case):
+def test_openai_chat_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)
     question = tc["question"]
     expected = tc["expected"]
```
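The streaming test bodies themselves are not part of these hunks. For completeness, the usual pattern for asserting on a streamed chat completion with the `openai` client looks roughly like this (the question and expected substring are placeholders, not taken from the repository's test cases):

```python
# Sketch of a streaming assertion pattern (illustrative, not the repository's test body).
def example_streaming_check(openai_client, text_model_id):
    stream = openai_client.chat.completions.create(
        model=text_model_id,
        messages=[{"role": "user", "content": "Which planet do humans live on?"}],
        stream=True,
    )
    pieces = []
    for chunk in stream:
        # Some chunks (e.g. a final usage chunk) may carry no choices or an empty delta.
        if chunk.choices and chunk.choices[0].delta.content:
            pieces.append(chunk.choices[0].delta.content)
    assert "earth" in "".join(pieces).lower()
```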