From a5827f7cb33916ba69b088e531ee27ae3960c9da Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Thu, 10 Apr 2025 13:43:28 -0400 Subject: [PATCH] Nvidia provider support for OpenAI API endpoints This wires up the openai_completion and openai_chat_completion API methods for the remote Nvidia inference provider, and adds it to the chat completions part of the OpenAI test suite. The hosted Nvidia service doesn't actually host any Llama models with functioning completions and chat completions endpoints, so for now the test suite only activates the nvidia provider for chat completions. Signed-off-by: Ben Browning --- .../remote/inference/nvidia/nvidia.py | 121 ++++++++++++++++-- .../inference/test_openai_completion.py | 37 +++++- 2 files changed, 143 insertions(+), 15 deletions(-) diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 3ed458058..d6f717719 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -7,7 +7,7 @@ import logging import warnings from functools import lru_cache -from typing import AsyncIterator, List, Optional, Union +from typing import Any, AsyncIterator, Dict, List, Optional, Union from openai import APIConnectionError, AsyncOpenAI, BadRequestError @@ -35,15 +35,15 @@ from llama_stack.apis.inference import ( ToolConfig, ToolDefinition, ) +from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam from llama_stack.models.llama.datatypes import ToolPromptFormat from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) from llama_stack.providers.utils.inference.openai_compat import ( - OpenAIChatCompletionUnsupportedMixin, - OpenAICompletionUnsupportedMixin, convert_openai_chat_completion_choice, convert_openai_chat_completion_stream, + prepare_openai_completion_params, ) from llama_stack.providers.utils.inference.prompt_adapter import content_has_media @@ -60,12 +60,7 @@ from .utils import _is_nvidia_hosted logger = logging.getLogger(__name__) -class NVIDIAInferenceAdapter( - Inference, - OpenAIChatCompletionUnsupportedMixin, - OpenAICompletionUnsupportedMixin, - ModelRegistryHelper, -): +class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): def __init__(self, config: NVIDIAConfig) -> None: # TODO(mf): filter by available models ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) @@ -270,3 +265,111 @@ class NVIDIAInferenceAdapter( else: # we pass n=1 to get only one completion return convert_openai_chat_completion_choice(response.choices[0]) + + async def openai_completion( + self, + model: str, + prompt: Union[str, List[str], List[int], List[List[int]]], + best_of: Optional[int] = None, + echo: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, + ) -> OpenAICompletion: + provider_model_id = self.get_provider_model_id(model) + + params = await prepare_openai_completion_params( + model=provider_model_id, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + top_p=top_p, + user=user, + ) + + try: + return await self._get_client(provider_model_id).completions.create(**params) + except APIConnectionError as e: + raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e + + async def openai_chat_completion( + self, + model: str, + messages: List[OpenAIMessageParam], + frequency_penalty: Optional[float] = None, + function_call: Optional[Union[str, Dict[str, Any]]] = None, + functions: Optional[List[Dict[str, Any]]] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + parallel_tool_calls: Optional[bool] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[Dict[str, str]] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[str, Dict[str, Any]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> OpenAIChatCompletion: + provider_model_id = self.get_provider_model_id(model) + + params = await prepare_openai_completion_params( + model=provider_model_id, + messages=messages, + frequency_penalty=frequency_penalty, + function_call=function_call, + functions=functions, + logit_bias=logit_bias, + logprobs=logprobs, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + n=n, + parallel_tool_calls=parallel_tool_calls, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, + top_logprobs=top_logprobs, + top_p=top_p, + user=user, + ) + + try: + return await self._get_client(provider_model_id).chat.completions.create(**params) + except APIConnectionError as e: + raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index d94390b8f..e6e584727 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -33,6 +33,9 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id) "remote::bedrock", "remote::cerebras", "remote::databricks", + # Technically Nvidia does support OpenAI completions, but none of their hosted models + # support both completions and chat completions endpoint and all the Llama models are + # just chat completions "remote::nvidia", "remote::runpod", "remote::sambanova", @@ -41,6 +44,25 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id) pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.") +def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id): + if isinstance(client_with_models, LlamaStackAsLibraryClient): + pytest.skip("OpenAI chat completions are not supported when testing with library client yet.") + + provider = provider_from_model(client_with_models, model_id) + if provider.provider_type in ( + "inline::meta-reference", + "inline::sentence-transformers", + "inline::vllm", + "remote::bedrock", + "remote::cerebras", + "remote::databricks", + "remote::runpod", + "remote::sambanova", + "remote::tgi", + ): + pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.") + + def skip_if_provider_isnt_vllm(client_with_models, model_id): provider = provider_from_model(client_with_models, model_id) if provider.provider_type != "remote::vllm": @@ -48,8 +70,7 @@ def skip_if_provider_isnt_vllm(client_with_models, model_id): @pytest.fixture -def openai_client(client_with_models, text_model_id): - skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id) +def openai_client(client_with_models): base_url = f"{client_with_models.base_url}/v1/openai/v1" return OpenAI(base_url=base_url, api_key="bar") @@ -60,7 +81,8 @@ def openai_client(client_with_models, text_model_id): "inference:completion:sanity", ], ) -def test_openai_completion_non_streaming(openai_client, text_model_id, test_case): +def test_openai_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case): + skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id) tc = TestCase(test_case) # ollama needs more verbose prompting for some reason here... @@ -81,7 +103,8 @@ def test_openai_completion_non_streaming(openai_client, text_model_id, test_case "inference:completion:sanity", ], ) -def test_openai_completion_streaming(openai_client, text_model_id, test_case): +def test_openai_completion_streaming(openai_client, client_with_models, text_model_id, test_case): + skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id) tc = TestCase(test_case) # ollama needs more verbose prompting for some reason here... @@ -145,7 +168,8 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text "inference:chat_completion:non_streaming_02", ], ) -def test_openai_chat_completion_non_streaming(openai_client, text_model_id, test_case): +def test_openai_chat_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case): + skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) tc = TestCase(test_case) question = tc["question"] expected = tc["expected"] @@ -172,7 +196,8 @@ def test_openai_chat_completion_non_streaming(openai_client, text_model_id, test "inference:chat_completion:streaming_02", ], ) -def test_openai_chat_completion_streaming(openai_client, text_model_id, test_case): +def test_openai_chat_completion_streaming(openai_client, client_with_models, text_model_id, test_case): + skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) tc = TestCase(test_case) question = tc["question"] expected = tc["expected"]