mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
Nvidia provider support for OpenAI API endpoints
This wires up the openai_completion and openai_chat_completion API methods for the remote Nvidia inference provider, and adds it to the chat completions part of the OpenAI test suite. The hosted Nvidia service doesn't actually host any Llama models with functioning completions and chat completions endpoints, so for now the test suite only activates the nvidia provider for chat completions. Signed-off-by: Ben Browning <bbrownin@redhat.com>
This commit is contained in:
parent
8f5cd49159
commit
a5827f7cb3
2 changed files with 143 additions and 15 deletions
|
@ -7,7 +7,7 @@
|
|||
import logging
|
||||
import warnings
|
||||
from functools import lru_cache
|
||||
from typing import AsyncIterator, List, Optional, Union
|
||||
from typing import Any, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
from openai import APIConnectionError, AsyncOpenAI, BadRequestError
|
||||
|
||||
|
@ -35,15 +35,15 @@ from llama_stack.apis.inference import (
|
|||
ToolConfig,
|
||||
ToolDefinition,
|
||||
)
|
||||
from llama_stack.apis.inference.inference import OpenAIChatCompletion, OpenAICompletion, OpenAIMessageParam
|
||||
from llama_stack.models.llama.datatypes import ToolPromptFormat
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
convert_openai_chat_completion_choice,
|
||||
convert_openai_chat_completion_stream,
|
||||
prepare_openai_completion_params,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
|
||||
|
||||
|
@ -60,12 +60,7 @@ from .utils import _is_nvidia_hosted
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NVIDIAInferenceAdapter(
|
||||
Inference,
|
||||
OpenAIChatCompletionUnsupportedMixin,
|
||||
OpenAICompletionUnsupportedMixin,
|
||||
ModelRegistryHelper,
|
||||
):
|
||||
class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||
def __init__(self, config: NVIDIAConfig) -> None:
|
||||
# TODO(mf): filter by available models
|
||||
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||
|
@ -270,3 +265,111 @@ class NVIDIAInferenceAdapter(
|
|||
else:
|
||||
# we pass n=1 to get only one completion
|
||||
return convert_openai_chat_completion_choice(response.choices[0])
|
||||
|
||||
async def openai_completion(
|
||||
self,
|
||||
model: str,
|
||||
prompt: Union[str, List[str], List[int], List[List[int]]],
|
||||
best_of: Optional[int] = None,
|
||||
echo: Optional[bool] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
guided_choice: Optional[List[str]] = None,
|
||||
prompt_logprobs: Optional[int] = None,
|
||||
) -> OpenAICompletion:
|
||||
provider_model_id = self.get_provider_model_id(model)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=provider_model_id,
|
||||
prompt=prompt,
|
||||
best_of=best_of,
|
||||
echo=echo,
|
||||
frequency_penalty=frequency_penalty,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
presence_penalty=presence_penalty,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
try:
|
||||
return await self._get_client(provider_model_id).completions.create(**params)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
async def openai_chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: List[OpenAIMessageParam],
|
||||
frequency_penalty: Optional[float] = None,
|
||||
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
functions: Optional[List[Dict[str, Any]]] = None,
|
||||
logit_bias: Optional[Dict[str, float]] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
n: Optional[int] = None,
|
||||
parallel_tool_calls: Optional[bool] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
response_format: Optional[Dict[str, str]] = None,
|
||||
seed: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
stream: Optional[bool] = None,
|
||||
stream_options: Optional[Dict[str, Any]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
||||
tools: Optional[List[Dict[str, Any]]] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
user: Optional[str] = None,
|
||||
) -> OpenAIChatCompletion:
|
||||
provider_model_id = self.get_provider_model_id(model)
|
||||
|
||||
params = await prepare_openai_completion_params(
|
||||
model=provider_model_id,
|
||||
messages=messages,
|
||||
frequency_penalty=frequency_penalty,
|
||||
function_call=function_call,
|
||||
functions=functions,
|
||||
logit_bias=logit_bias,
|
||||
logprobs=logprobs,
|
||||
max_completion_tokens=max_completion_tokens,
|
||||
max_tokens=max_tokens,
|
||||
n=n,
|
||||
parallel_tool_calls=parallel_tool_calls,
|
||||
presence_penalty=presence_penalty,
|
||||
response_format=response_format,
|
||||
seed=seed,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
stream_options=stream_options,
|
||||
temperature=temperature,
|
||||
tool_choice=tool_choice,
|
||||
tools=tools,
|
||||
top_logprobs=top_logprobs,
|
||||
top_p=top_p,
|
||||
user=user,
|
||||
)
|
||||
|
||||
try:
|
||||
return await self._get_client(provider_model_id).chat.completions.create(**params)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
|
|
@ -33,6 +33,9 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
|
|||
"remote::bedrock",
|
||||
"remote::cerebras",
|
||||
"remote::databricks",
|
||||
# Technically Nvidia does support OpenAI completions, but none of their hosted models
|
||||
# support both completions and chat completions endpoint and all the Llama models are
|
||||
# just chat completions
|
||||
"remote::nvidia",
|
||||
"remote::runpod",
|
||||
"remote::sambanova",
|
||||
|
@ -41,6 +44,25 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
|
|||
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")
|
||||
|
||||
|
||||
def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
|
||||
if isinstance(client_with_models, LlamaStackAsLibraryClient):
|
||||
pytest.skip("OpenAI chat completions are not supported when testing with library client yet.")
|
||||
|
||||
provider = provider_from_model(client_with_models, model_id)
|
||||
if provider.provider_type in (
|
||||
"inline::meta-reference",
|
||||
"inline::sentence-transformers",
|
||||
"inline::vllm",
|
||||
"remote::bedrock",
|
||||
"remote::cerebras",
|
||||
"remote::databricks",
|
||||
"remote::runpod",
|
||||
"remote::sambanova",
|
||||
"remote::tgi",
|
||||
):
|
||||
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.")
|
||||
|
||||
|
||||
def skip_if_provider_isnt_vllm(client_with_models, model_id):
|
||||
provider = provider_from_model(client_with_models, model_id)
|
||||
if provider.provider_type != "remote::vllm":
|
||||
|
@ -48,8 +70,7 @@ def skip_if_provider_isnt_vllm(client_with_models, model_id):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_client(client_with_models, text_model_id):
|
||||
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
|
||||
def openai_client(client_with_models):
|
||||
base_url = f"{client_with_models.base_url}/v1/openai/v1"
|
||||
return OpenAI(base_url=base_url, api_key="bar")
|
||||
|
||||
|
@ -60,7 +81,8 @@ def openai_client(client_with_models, text_model_id):
|
|||
"inference:completion:sanity",
|
||||
],
|
||||
)
|
||||
def test_openai_completion_non_streaming(openai_client, text_model_id, test_case):
|
||||
def test_openai_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
|
||||
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
|
||||
tc = TestCase(test_case)
|
||||
|
||||
# ollama needs more verbose prompting for some reason here...
|
||||
|
@ -81,7 +103,8 @@ def test_openai_completion_non_streaming(openai_client, text_model_id, test_case
|
|||
"inference:completion:sanity",
|
||||
],
|
||||
)
|
||||
def test_openai_completion_streaming(openai_client, text_model_id, test_case):
|
||||
def test_openai_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
|
||||
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
|
||||
tc = TestCase(test_case)
|
||||
|
||||
# ollama needs more verbose prompting for some reason here...
|
||||
|
@ -145,7 +168,8 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text
|
|||
"inference:chat_completion:non_streaming_02",
|
||||
],
|
||||
)
|
||||
def test_openai_chat_completion_non_streaming(openai_client, text_model_id, test_case):
|
||||
def test_openai_chat_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
|
||||
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
||||
tc = TestCase(test_case)
|
||||
question = tc["question"]
|
||||
expected = tc["expected"]
|
||||
|
@ -172,7 +196,8 @@ def test_openai_chat_completion_non_streaming(openai_client, text_model_id, test
|
|||
"inference:chat_completion:streaming_02",
|
||||
],
|
||||
)
|
||||
def test_openai_chat_completion_streaming(openai_client, text_model_id, test_case):
|
||||
def test_openai_chat_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
|
||||
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
||||
tc = TestCase(test_case)
|
||||
question = tc["question"]
|
||||
expected = tc["expected"]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue