diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py
index 3ed458058..d6f717719 100644
--- a/llama_stack/providers/remote/inference/nvidia/nvidia.py
+++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py
@@ -7,7 +7,7 @@
 import logging
 import warnings
 from functools import lru_cache
-from typing import AsyncIterator, List, Optional, Union
+from typing import Any, AsyncIterator, Dict, List, Optional, Union

 from openai import APIConnectionError, AsyncOpenAI, BadRequestError

@@ -35,15 +35,15 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
 )
+from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
 from llama_stack.models.llama.datatypes import ToolPromptFormat
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
 from llama_stack.providers.utils.inference.openai_compat import (
-    OpenAIChatCompletionUnsupportedMixin,
-    OpenAICompletionUnsupportedMixin,
     convert_openai_chat_completion_choice,
     convert_openai_chat_completion_stream,
+    prepare_openai_completion_params,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import content_has_media

@@ -60,12 +60,7 @@ from .utils import _is_nvidia_hosted
 logger = logging.getLogger(__name__)


-class NVIDIAInferenceAdapter(
-    Inference,
-    OpenAIChatCompletionUnsupportedMixin,
-    OpenAICompletionUnsupportedMixin,
-    ModelRegistryHelper,
-):
+class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
     def __init__(self, config: NVIDIAConfig) -> None:
         # TODO(mf): filter by available models
         ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
@@ -270,3 +265,111 @@ class NVIDIAInferenceAdapter(
         else:
             # we pass n=1 to get only one completion
             return convert_openai_chat_completion_choice(response.choices[0])
+
+    async def openai_completion(
+        self,
+        model: str,
+        prompt: Union[str, List[str], List[int], List[List[int]]],
+        best_of: Optional[int] = None,
+        echo: Optional[bool] = None,
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
+    ) -> OpenAICompletion:
+        provider_model_id = self.get_provider_model_id(model)
+
+        params = await prepare_openai_completion_params(
+            model=provider_model_id,
+            prompt=prompt,
+            best_of=best_of,
+            echo=echo,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            top_p=top_p,
+            user=user,
+        )
+
+        try:
+            return await self._get_client(provider_model_id).completions.create(**params)
+        except APIConnectionError as e:
+            raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
+
+    async def openai_chat_completion(
+        self,
+        model: str,
+        messages: List[OpenAIMessageParam],
+        frequency_penalty: Optional[float] = None,
+        function_call: Optional[Union[str, Dict[str, Any]]] = None,
+        functions: Optional[List[Dict[str, Any]]] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_completion_tokens: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        parallel_tool_calls: Optional[bool] = None,
+        presence_penalty: Optional[float] = None,
+        response_format: Optional[Dict[str, str]] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        top_logprobs: Optional[int] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+    ) -> OpenAIChatCompletion:
+        provider_model_id = self.get_provider_model_id(model)
+
+        params = await prepare_openai_completion_params(
+            model=provider_model_id,
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            function_call=function_call,
+            functions=functions,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_completion_tokens=max_completion_tokens,
+            max_tokens=max_tokens,
+            n=n,
+            parallel_tool_calls=parallel_tool_calls,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            tools=tools,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
+            user=user,
+        )
+
+        try:
+            return await self._get_client(provider_model_id).chat.completions.create(**params)
+        except APIConnectionError as e:
+            raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py
index d94390b8f..e6e584727 100644
--- a/tests/integration/inference/test_openai_completion.py
+++ b/tests/integration/inference/test_openai_completion.py
@@ -33,6 +33,9 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
         "remote::bedrock",
         "remote::cerebras",
         "remote::databricks",
+        # Technically Nvidia does support OpenAI completions, but none of their hosted models
+        # support both the completions and chat completions endpoints, and the Llama models
+        # only support chat completions
         "remote::nvidia",
         "remote::runpod",
         "remote::sambanova",
@@ -41,6 +44,25 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
         pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")


+def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
+    if isinstance(client_with_models, LlamaStackAsLibraryClient):
+        pytest.skip("OpenAI chat completions are not supported when testing with library client yet.")
+
+    provider = provider_from_model(client_with_models, model_id)
+    if provider.provider_type in (
+        "inline::meta-reference",
+        "inline::sentence-transformers",
+        "inline::vllm",
+        "remote::bedrock",
+        "remote::cerebras",
+        "remote::databricks",
+        "remote::runpod",
+        "remote::sambanova",
+        "remote::tgi",
+    ):
+        pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.")
+
+
 def skip_if_provider_isnt_vllm(client_with_models, model_id):
     provider = provider_from_model(client_with_models, model_id)
     if provider.provider_type != "remote::vllm":
@@ -48,8 +70,7 @@ def skip_if_provider_isnt_vllm(client_with_models, model_id):


 @pytest.fixture
-def openai_client(client_with_models, text_model_id):
-    skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
+def openai_client(client_with_models):
     base_url = f"{client_with_models.base_url}/v1/openai/v1"
     return OpenAI(base_url=base_url, api_key="bar")

@@ -60,7 +81,8 @@ def openai_client(client_with_models, text_model_id):
         "inference:completion:sanity",
     ],
 )
-def test_openai_completion_non_streaming(openai_client, text_model_id, test_case):
+def test_openai_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)

     # ollama needs more verbose prompting for some reason here...
@@ -81,7 +103,8 @@ def test_openai_completion_non_streaming(openai_client, text_model_id, test_case
         "inference:completion:sanity",
     ],
 )
-def test_openai_completion_streaming(openai_client, text_model_id, test_case):
+def test_openai_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)

     # ollama needs more verbose prompting for some reason here...
@@ -145,7 +168,8 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text
         "inference:chat_completion:non_streaming_02",
     ],
 )
-def test_openai_chat_completion_non_streaming(openai_client, text_model_id, test_case):
+def test_openai_chat_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)
     question = tc["question"]
     expected = tc["expected"]
@@ -172,7 +196,8 @@ def test_openai_chat_completion_non_streaming(openai_client, text_model_id, test
         "inference:chat_completion:streaming_02",
     ],
 )
-def test_openai_chat_completion_streaming(openai_client, text_model_id, test_case):
+def test_openai_chat_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
+    skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)
     question = tc["question"]
     expected = tc["expected"]
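
Illustrative note (not part of the patch): the integration tests above exercise the new NVIDIA OpenAI-compatible chat completions path by pointing the stock OpenAI client at a running Llama Stack server, so a minimal manual check could look like the sketch below. The /v1/openai/v1 suffix and the placeholder api_key come from the openai_client fixture in the test file; the host, port, and model id are assumptions that depend on your deployment and which model is registered with the remote::nvidia provider.

    from openai import OpenAI

    # Assumed local Llama Stack endpoint; adjust host/port to your deployment.
    client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="bar")

    # Assumed model id registered with the remote::nvidia provider.
    response = client.chat.completions.create(
        model="meta-llama/Llama-3.1-8B-Instruct",
        messages=[{"role": "user", "content": "Hello, who are you?"}],
        stream=False,
    )
    print(response.choices[0].message.content)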