Nvidia provider support for OpenAI API endpoints

This wires up the openai_completion and openai_chat_completion API
methods for the remote Nvidia inference provider, and adds that
provider to the chat completions part of the OpenAI test suite.

The hosted Nvidia service doesn't currently serve any Llama models
with a functioning completions endpoint (its Llama models only support
chat completions), so for now the test suite only enables the Nvidia
provider for chat completions.

Signed-off-by: Ben Browning <bbrownin@redhat.com>
Ben Browning 2025-04-10 13:43:28 -04:00
parent 8f5cd49159
commit a5827f7cb3
2 changed files with 143 additions and 15 deletions
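
For context, wiring up openai_chat_completion on a remote provider mostly amounts to forwarding the request to an OpenAI-compatible client pointed at the hosted service. A minimal sketch of that shape, assuming the adapter keeps an AsyncOpenAI client around (the class name, attribute names, and parameter handling below are illustrative, not taken from this commit):

from openai import AsyncOpenAI


class ExampleNvidiaAdapter:
    """Illustrative sketch only; not the adapter class from this commit."""

    def __init__(self, base_url: str, api_key: str):
        # The hosted service speaks the OpenAI API, so the stock client can be reused.
        self._client = AsyncOpenAI(base_url=base_url, api_key=api_key)

    async def openai_chat_completion(self, model: str, messages: list, stream: bool = False, **params):
        # Forward the request largely unchanged; a real adapter would also map
        # provider model IDs and drop parameters the service doesn't accept.
        return await self._client.chat.completions.create(
            model=model,
            messages=messages,
            stream=stream,
            **params,
        )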

@@ -33,6 +33,9 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
         "remote::bedrock",
         "remote::cerebras",
         "remote::databricks",
+        # Technically Nvidia does support OpenAI completions, but none of their hosted models
+        # support both completions and chat completions endpoint and all the Llama models are
+        # just chat completions
+        "remote::nvidia",
         "remote::runpod",
         "remote::sambanova",
@@ -41,6 +44,25 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")
 
 
+def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
+    if isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("OpenAI chat completions are not supported when testing with library client yet.")
+
+    provider = provider_from_model(client_with_models, model_id)
+    if provider.provider_type in (
+        "inline::meta-reference",
+        "inline::sentence-transformers",
+        "inline::vllm",
+        "remote::bedrock",
+        "remote::cerebras",
+        "remote::databricks",
+        "remote::runpod",
+        "remote::sambanova",
+        "remote::tgi",
+    ):
+        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.")
+
+
 def skip_if_provider_isnt_vllm(client_with_models, model_id):
     provider = provider_from_model(client_with_models, model_id)
     if provider.provider_type != "remote::vllm":
@@ -48,8 +70,7 @@ def skip_if_provider_isnt_vllm(client_with_models, model_id):
 @pytest.fixture
-def openai_client(client_with_models, text_model_id):
-    skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
+def openai_client(client_with_models):
     base_url = f"{client_with_models.base_url}/v1/openai/v1"
     return OpenAI(base_url=base_url, api_key="bar")
@@ -60,7 +81,8 @@ def openai_client(client_with_models, text_model_id):
         "inference:completion:sanity",
     ],
 )
-def test_openai_completion_non_streaming(openai_client, text_model_id, test_case):
+def test_openai_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)
     # ollama needs more verbose prompting for some reason here...
@@ -81,7 +103,8 @@ def test_openai_completion_non_streaming(openai_client, text_model_id, test_case
         "inference:completion:sanity",
     ],
 )
-def test_openai_completion_streaming(openai_client, text_model_id, test_case):
+def test_openai_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)
     # ollama needs more verbose prompting for some reason here...
@@ -145,7 +168,8 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text
         "inference:chat_completion:non_streaming_02",
     ],
 )
-def test_openai_chat_completion_non_streaming(openai_client, text_model_id, test_case):
+def test_openai_chat_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)
     question = tc["question"]
     expected = tc["expected"]
@@ -172,7 +196,8 @@ def test_openai_chat_completion_non_streaming(openai_client, text_model_id, test
         "inference:chat_completion:streaming_02",
     ],
 )
-def test_openai_chat_completion_streaming(openai_client, text_model_id, test_case):
+def test_openai_chat_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)
     question = tc["question"]
     expected = tc["expected"]
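
For reference, the tests above drive these endpoints through the stock OpenAI client pointed at the stack's OpenAI-compatible route, roughly as follows (the base URL and model ID are placeholders; the api_key value mirrors the fixture above, which doesn't require a real key):

from openai import OpenAI

# Mirrors the openai_client fixture: the stack serves an OpenAI-compatible API
# under /v1/openai/v1, and any api_key value works against a local test server.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="bar")

response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model ID
    messages=[{"role": "user", "content": "Which planet do humans live on?"}],
    stream=False,
)
print(response.choices[0].message.content)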