From e1ed1527795170c9f14eb43d1ba163926eb148a8 Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Wed, 23 Jul 2025 06:49:40 -0400
Subject: [PATCH] chore: create OpenAIMixin for inference providers with an OpenAI-compat API that need to implement openai_* methods (#2835)

# What does this PR do?

Add an `OpenAIMixin` for use by inference providers whose remote endpoints support an OpenAI-compatible API.

Use is demonstrated by refactoring:
- OpenAIInferenceAdapter
- NVIDIAInferenceAdapter (adds embedding support)
- LlamaCompatInferenceAdapter

## Test Plan

Existing unit and integration tests.
---
 docs/source/contributing/new_api_provider.md  |  35 +++
 .../inference/llama_openai_compat/llama.py    |  43 ++-
 .../remote/inference/nvidia/nvidia.py         | 191 ++----------
 .../remote/inference/openai/openai.py         | 223 ++------------
 .../utils/inference/model_registry.py         |   6 +
 .../providers/utils/inference/openai_mixin.py | 272 ++++++++++++++++++
 .../test_inference_client_caching.py          |  19 +-
 7 files changed, 402 insertions(+), 387 deletions(-)
 create mode 100644 llama_stack/providers/utils/inference/openai_mixin.py

diff --git a/docs/source/contributing/new_api_provider.md b/docs/source/contributing/new_api_provider.md
index 83058896a..01a8ec093 100644
--- a/docs/source/contributing/new_api_provider.md
+++ b/docs/source/contributing/new_api_provider.md
@@ -14,6 +14,41 @@ Here are some example PRs to help you get started:
 - [Nvidia Inference Implementation](https://github.com/meta-llama/llama-stack/pull/355)
 - [Model context protocol Tool Runtime](https://github.com/meta-llama/llama-stack/pull/665)
 
+## Inference Provider Patterns
+
+Llama Stack provides several mixin classes that simplify development and ensure consistent behavior when implementing inference providers for OpenAI-compatible APIs.
+
+### OpenAIMixin
+
+The `OpenAIMixin` class provides direct OpenAI API functionality for providers that work with OpenAI-compatible endpoints. It includes:
+
+#### Direct API Methods
+- **`openai_completion()`**: Legacy text completion API with full parameter support
+- **`openai_chat_completion()`**: Chat completion API supporting streaming, tools, and function calling
+- **`openai_embeddings()`**: Text embeddings generation with customizable encoding and dimensions
+
+#### Model Management
+- **`check_model_availability()`**: Queries the API endpoint to verify whether a model exists and is accessible
+
+#### Client Management
+- **`client` property**: Automatically creates and configures AsyncOpenAI client instances using your provider's credentials
+
+#### Required Implementation
+
+To use `OpenAIMixin`, your provider must implement these abstract methods:
+
+```python
+@abstractmethod
+def get_api_key(self) -> str:
+    """Return the API key for authentication"""
+    pass
+
+
+@abstractmethod
+def get_base_url(self) -> str:
+    """Return the OpenAI-compatible API base URL"""
+    pass
+```
 
 ## Testing the Provider

diff --git a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py
index 5f9cb20b2..576080d99 100644
--- a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py
+++ b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py
@@ -5,17 +5,27 @@
 # the root directory of this source tree.
import logging -from llama_api_client import AsyncLlamaAPIClient, NotFoundError - from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .models import MODEL_ENTRIES logger = logging.getLogger(__name__) -class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin): +class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): + """ + Llama API Inference Adapter for Llama Stack. + + Note: The inheritance order is important here. OpenAIMixin must come before + LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability() + is used instead of ModelRegistryHelper.check_model_availability(). + + - OpenAIMixin.check_model_availability() queries the Llama API to check if a model exists + - ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning + """ + _config: LlamaCompatConfig def __init__(self, config: LlamaCompatConfig): @@ -28,32 +38,19 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin): ) self.config = config - async def check_model_availability(self, model: str) -> bool: + # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin + get_api_key = LiteLLMOpenAIMixin.get_api_key + + def get_base_url(self) -> str: """ - Check if a specific model is available from Llama API. + Get the base URL for OpenAI mixin. - :param model: The model identifier to check. - :return: True if the model is available dynamically, False otherwise. + :return: The Llama API base URL """ - try: - llama_api_client = self._get_llama_api_client() - retrieved_model = await llama_api_client.models.retrieve(model) - logger.info(f"Model {retrieved_model.id} is available from Llama API") - return True - - except NotFoundError: - logger.error(f"Model {model} is not available from Llama API") - return False - - except Exception as e: - logger.error(f"Failed to check model availability from Llama API: {e}") - return False + return self.config.openai_compat_api_base async def initialize(self): await super().initialize() async def shutdown(self): await super().shutdown() - - def _get_llama_api_client(self) -> AsyncLlamaAPIClient: - return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base) diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index cb7554523..7bc3fd0c9 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -7,9 +7,8 @@ import logging import warnings from collections.abc import AsyncIterator -from typing import Any -from openai import APIConnectionError, AsyncOpenAI, BadRequestError, NotFoundError +from openai import APIConnectionError, BadRequestError from llama_stack.apis.common.content_types import ( InterleavedContent, @@ -28,12 +27,6 @@ from llama_stack.apis.inference import ( Inference, LogProbConfig, Message, - OpenAIChatCompletion, - OpenAIChatCompletionChunk, - OpenAICompletion, - OpenAIEmbeddingsResponse, - OpenAIMessageParam, - OpenAIResponseFormatParam, ResponseFormat, SamplingParams, TextTruncation, @@ -47,8 +40,8 @@ from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.openai_compat import ( convert_openai_chat_completion_choice, 
convert_openai_chat_completion_stream, - prepare_openai_completion_params, ) +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from llama_stack.providers.utils.inference.prompt_adapter import content_has_media from . import NVIDIAConfig @@ -64,7 +57,20 @@ from .utils import _is_nvidia_hosted logger = logging.getLogger(__name__) -class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): +class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper): + """ + NVIDIA Inference Adapter for Llama Stack. + + Note: The inheritance order is important here. OpenAIMixin must come before + ModelRegistryHelper to ensure that OpenAIMixin.check_model_availability() + is used instead of ModelRegistryHelper.check_model_availability(). It also + must come before Inference to ensure that OpenAIMixin methods are available + in the Inference interface. + + - OpenAIMixin.check_model_availability() queries the NVIDIA API to check if a model exists + - ModelRegistryHelper.check_model_availability() just returns False and shows a warning + """ + def __init__(self, config: NVIDIAConfig) -> None: # TODO(mf): filter by available models ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) @@ -88,45 +94,21 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): self._config = config - async def check_model_availability(self, model: str) -> bool: + def get_api_key(self) -> str: """ - Check if a specific model is available. + Get the API key for OpenAI mixin. - :param model: The model identifier to check. - :return: True if the model is available dynamically, False otherwise. + :return: The NVIDIA API key """ - try: - await self._client.models.retrieve(model) - return True - except NotFoundError: - logger.error(f"Model {model} is not available") - except Exception as e: - logger.error(f"Failed to check model availability: {e}") - return False + return self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY" - @property - def _client(self) -> AsyncOpenAI: + def get_base_url(self) -> str: """ - Returns an OpenAI client for the configured NVIDIA API endpoint. + Get the base URL for OpenAI mixin. 
- :return: An OpenAI client + :return: The NVIDIA API base URL """ - - base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url - - return AsyncOpenAI( - base_url=base_url, - api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"), - timeout=self._config.timeout, - ) - - async def _get_provider_model_id(self, model_id: str) -> str: - if not self.model_store: - raise RuntimeError("Model store is not set") - model = await self.model_store.get_model(model_id) - if model is None: - raise ValueError(f"Model {model_id} is unknown") - return model.provider_model_id + return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url async def completion( self, @@ -160,7 +142,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): ) try: - response = await self._client.completions.create(**request) + response = await self.client.completions.create(**request) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e @@ -213,7 +195,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): extra_body["input_type"] = task_type_options[task_type] try: - response = await self._client.embeddings.create( + response = await self.client.embeddings.create( model=provider_model_id, input=input, extra_body=extra_body, @@ -228,16 +210,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): # return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data]) - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() - async def chat_completion( self, model_id: str, @@ -274,7 +246,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): ) try: - response = await self._client.chat.completions.create(**request) + response = await self.client.chat.completions.create(**request) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e @@ -283,112 +255,3 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): else: # we pass n=1 to get only one completion return convert_openai_chat_completion_choice(response.choices[0]) - - async def openai_completion( - self, - model: str, - prompt: str | list[str] | list[int] | list[list[int]], - best_of: int | None = None, - echo: bool | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_tokens: int | None = None, - n: int | None = None, - presence_penalty: float | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - top_p: float | None = None, - user: str | None = None, - guided_choice: list[str] | None = None, - prompt_logprobs: int | None = None, - suffix: str | None = None, - ) -> OpenAICompletion: - provider_model_id = await self._get_provider_model_id(model) - - params = await prepare_openai_completion_params( - model=provider_model_id, - prompt=prompt, - best_of=best_of, - echo=echo, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - logprobs=logprobs, - max_tokens=max_tokens, - n=n, - presence_penalty=presence_penalty, - seed=seed, - stop=stop, - stream=stream, - 
stream_options=stream_options, - temperature=temperature, - top_p=top_p, - user=user, - ) - - try: - return await self._client.completions.create(**params) - except APIConnectionError as e: - raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e - - async def openai_chat_completion( - self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, - ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - provider_model_id = await self._get_provider_model_id(model) - - params = await prepare_openai_completion_params( - model=provider_model_id, - messages=messages, - frequency_penalty=frequency_penalty, - function_call=function_call, - functions=functions, - logit_bias=logit_bias, - logprobs=logprobs, - max_completion_tokens=max_completion_tokens, - max_tokens=max_tokens, - n=n, - parallel_tool_calls=parallel_tool_calls, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - tool_choice=tool_choice, - tools=tools, - top_logprobs=top_logprobs, - top_p=top_p, - user=user, - ) - - try: - return await self._client.chat.completions.create(**params) - except APIConnectionError as e: - raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py index 7e167f621..9e1b77bde 100644 --- a/llama_stack/providers/remote/inference/openai/openai.py +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -5,23 +5,9 @@ # the root directory of this source tree. 
import logging -from collections.abc import AsyncIterator -from typing import Any -from openai import AsyncOpenAI, NotFoundError - -from llama_stack.apis.inference import ( - OpenAIChatCompletion, - OpenAIChatCompletionChunk, - OpenAICompletion, - OpenAIEmbeddingData, - OpenAIEmbeddingsResponse, - OpenAIEmbeddingUsage, - OpenAIMessageParam, - OpenAIResponseFormatParam, -) from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin -from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import OpenAIConfig from .models import MODEL_ENTRIES @@ -30,7 +16,7 @@ logger = logging.getLogger(__name__) # -# This OpenAI adapter implements Inference methods using two clients - +# This OpenAI adapter implements Inference methods using two mixins - # # | Inference Method | Implementation Source | # |----------------------------|--------------------------| @@ -39,11 +25,22 @@ logger = logging.getLogger(__name__) # | embedding | LiteLLMOpenAIMixin | # | batch_completion | LiteLLMOpenAIMixin | # | batch_chat_completion | LiteLLMOpenAIMixin | -# | openai_completion | AsyncOpenAI | -# | openai_chat_completion | AsyncOpenAI | -# | openai_embeddings | AsyncOpenAI | +# | openai_completion | OpenAIMixin | +# | openai_chat_completion | OpenAIMixin | +# | openai_embeddings | OpenAIMixin | # -class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): +class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): + """ + OpenAI Inference Adapter for Llama Stack. + + Note: The inheritance order is important here. OpenAIMixin must come before + LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability() + is used instead of ModelRegistryHelper.check_model_availability(). + + - OpenAIMixin.check_model_availability() queries the OpenAI API to check if a model exists + - ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning + """ + def __init__(self, config: OpenAIConfig) -> None: LiteLLMOpenAIMixin.__init__( self, @@ -60,191 +57,19 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): # litellm specific model names, an abstraction leak. self.is_openai_compat = True - async def check_model_availability(self, model: str) -> bool: + # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin + get_api_key = LiteLLMOpenAIMixin.get_api_key + + def get_base_url(self) -> str: """ - Check if a specific model is available from OpenAI. + Get the OpenAI API base URL. - :param model: The model identifier to check. - :return: True if the model is available dynamically, False otherwise. + Returns the standard OpenAI API base URL for direct OpenAI API calls. 
""" - try: - openai_client = self._get_openai_client() - retrieved_model = await openai_client.models.retrieve(model) - logger.info(f"Model {retrieved_model.id} is available from OpenAI") - return True - - except NotFoundError: - logger.error(f"Model {model} is not available from OpenAI") - return False - - except Exception as e: - logger.error(f"Failed to check model availability from OpenAI: {e}") - return False + return "https://api.openai.com/v1" async def initialize(self) -> None: await super().initialize() async def shutdown(self) -> None: await super().shutdown() - - def _get_openai_client(self) -> AsyncOpenAI: - return AsyncOpenAI( - api_key=self.get_api_key(), - ) - - async def openai_completion( - self, - model: str, - prompt: str | list[str] | list[int] | list[list[int]], - best_of: int | None = None, - echo: bool | None = None, - frequency_penalty: float | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_tokens: int | None = None, - n: int | None = None, - presence_penalty: float | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - top_p: float | None = None, - user: str | None = None, - guided_choice: list[str] | None = None, - prompt_logprobs: int | None = None, - suffix: str | None = None, - ) -> OpenAICompletion: - if guided_choice is not None: - logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.") - if prompt_logprobs is not None: - logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.") - - model_id = (await self.model_store.get_model(model)).provider_resource_id - if model_id.startswith("openai/"): - model_id = model_id[len("openai/") :] - params = await prepare_openai_completion_params( - model=model_id, - prompt=prompt, - best_of=best_of, - echo=echo, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - logprobs=logprobs, - max_tokens=max_tokens, - n=n, - presence_penalty=presence_penalty, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - top_p=top_p, - user=user, - suffix=suffix, - ) - return await self._get_openai_client().completions.create(**params) - - async def openai_chat_completion( - self, - model: str, - messages: list[OpenAIMessageParam], - frequency_penalty: float | None = None, - function_call: str | dict[str, Any] | None = None, - functions: list[dict[str, Any]] | None = None, - logit_bias: dict[str, float] | None = None, - logprobs: bool | None = None, - max_completion_tokens: int | None = None, - max_tokens: int | None = None, - n: int | None = None, - parallel_tool_calls: bool | None = None, - presence_penalty: float | None = None, - response_format: OpenAIResponseFormatParam | None = None, - seed: int | None = None, - stop: str | list[str] | None = None, - stream: bool | None = None, - stream_options: dict[str, Any] | None = None, - temperature: float | None = None, - tool_choice: str | dict[str, Any] | None = None, - tools: list[dict[str, Any]] | None = None, - top_logprobs: int | None = None, - top_p: float | None = None, - user: str | None = None, - ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: - model_id = (await self.model_store.get_model(model)).provider_resource_id - if model_id.startswith("openai/"): - model_id = model_id[len("openai/") :] - params = await prepare_openai_completion_params( - model=model_id, - 
messages=messages, - frequency_penalty=frequency_penalty, - function_call=function_call, - functions=functions, - logit_bias=logit_bias, - logprobs=logprobs, - max_completion_tokens=max_completion_tokens, - max_tokens=max_tokens, - n=n, - parallel_tool_calls=parallel_tool_calls, - presence_penalty=presence_penalty, - response_format=response_format, - seed=seed, - stop=stop, - stream=stream, - stream_options=stream_options, - temperature=temperature, - tool_choice=tool_choice, - tools=tools, - top_logprobs=top_logprobs, - top_p=top_p, - user=user, - ) - return await self._get_openai_client().chat.completions.create(**params) - - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - model_id = (await self.model_store.get_model(model)).provider_resource_id - if model_id.startswith("openai/"): - model_id = model_id[len("openai/") :] - - # Prepare parameters for OpenAI embeddings API - params = { - "model": model_id, - "input": input, - } - - if encoding_format is not None: - params["encoding_format"] = encoding_format - if dimensions is not None: - params["dimensions"] = dimensions - if user is not None: - params["user"] = user - - # Call OpenAI embeddings API - response = await self._get_openai_client().embeddings.create(**params) - - data = [] - for i, embedding_data in enumerate(response.data): - data.append( - OpenAIEmbeddingData( - embedding=embedding_data.embedding, - index=i, - ) - ) - - usage = OpenAIEmbeddingUsage( - prompt_tokens=response.usage.prompt_tokens, - total_tokens=response.usage.total_tokens, - ) - - return OpenAIEmbeddingsResponse( - data=data, - model=response.model, - usage=usage, - ) diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 801b8ea06..651d58e2a 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -10,12 +10,15 @@ from pydantic import BaseModel, Field from llama_stack.apis.common.errors import UnsupportedModelError from llama_stack.apis.models import ModelType +from llama_stack.log import get_logger from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference import ( ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR, ) +logger = get_logger(name=__name__, category="core") + # TODO: this class is more confusing than useful right now. We need to make it # more closer to the Model class. @@ -98,6 +101,9 @@ class ModelRegistryHelper(ModelsProtocolPrivate): :param model: The model identifier to check. :return: True if the model is available dynamically, False otherwise. """ + logger.info( + f"check_model_availability is not implemented for {self.__class__.__name__}. Returning False by default." + ) return False async def register_model(self, model: Model) -> Model: diff --git a/llama_stack/providers/utils/inference/openai_mixin.py b/llama_stack/providers/utils/inference/openai_mixin.py new file mode 100644 index 000000000..72286dffb --- /dev/null +++ b/llama_stack/providers/utils/inference/openai_mixin.py @@ -0,0 +1,272 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from abc import ABC, abstractmethod +from collections.abc import AsyncIterator +from typing import Any + +import openai +from openai import NOT_GIVEN, AsyncOpenAI + +from llama_stack.apis.inference import ( + Model, + OpenAIChatCompletion, + OpenAIChatCompletionChunk, + OpenAICompletion, + OpenAIEmbeddingData, + OpenAIEmbeddingsResponse, + OpenAIEmbeddingUsage, + OpenAIMessageParam, + OpenAIResponseFormatParam, +) +from llama_stack.log import get_logger +from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params + +logger = get_logger(name=__name__, category="core") + + +class OpenAIMixin(ABC): + """ + Mixin class that provides OpenAI-specific functionality for inference providers. + This class handles direct OpenAI API calls using the AsyncOpenAI client. + + This is an abstract base class that requires child classes to implement: + - get_api_key(): Method to retrieve the API key + - get_base_url(): Method to retrieve the OpenAI-compatible API base URL + + Expected Dependencies: + - self.model_store: Injected by the Llama Stack distribution system at runtime. + This provides model registry functionality for looking up registered models. + The model_store is set in routing_tables/common.py during provider initialization. + """ + + @abstractmethod + def get_api_key(self) -> str: + """ + Get the API key. + + This method must be implemented by child classes to provide the API key + for authenticating with the OpenAI API or compatible endpoints. + + :return: The API key as a string + """ + pass + + @abstractmethod + def get_base_url(self) -> str: + """ + Get the OpenAI-compatible API base URL. + + This method must be implemented by child classes to provide the base URL + for the OpenAI API or compatible endpoints (e.g., "https://api.openai.com/v1"). + + :return: The base URL as a string + """ + pass + + @property + def client(self) -> AsyncOpenAI: + """ + Get an AsyncOpenAI client instance. + + Uses the abstract methods get_api_key() and get_base_url() which must be + implemented by child classes. + """ + return AsyncOpenAI( + api_key=self.get_api_key(), + base_url=self.get_base_url(), + ) + + async def _get_provider_model_id(self, model: str) -> str: + """ + Get the provider-specific model ID from the model store. + + This is a utility method that looks up the registered model and returns + the provider_resource_id that should be used for actual API calls. 
+ + :param model: The registered model name/identifier + :return: The provider-specific model ID (e.g., "gpt-4") + """ + # Look up the registered model to get the provider-specific model ID + # self.model_store is injected by the distribution system at runtime + model_obj: Model = await self.model_store.get_model(model) # type: ignore[attr-defined] + # provider_resource_id is str | None, but we expect it to be str for OpenAI calls + if model_obj.provider_resource_id is None: + raise ValueError(f"Model {model} has no provider_resource_id") + return model_obj.provider_resource_id + + async def openai_completion( + self, + model: str, + prompt: str | list[str] | list[int] | list[list[int]], + best_of: int | None = None, + echo: bool | None = None, + frequency_penalty: float | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_tokens: int | None = None, + n: int | None = None, + presence_penalty: float | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + top_p: float | None = None, + user: str | None = None, + guided_choice: list[str] | None = None, + prompt_logprobs: int | None = None, + suffix: str | None = None, + ) -> OpenAICompletion: + """ + Direct OpenAI completion API call. + """ + if guided_choice is not None: + logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.") + if prompt_logprobs is not None: + logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.") + + # TODO: fix openai_completion to return type compatible with OpenAI's API response + return await self.client.completions.create( # type: ignore[no-any-return] + **await prepare_openai_completion_params( + model=await self._get_provider_model_id(model), + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + top_p=top_p, + user=user, + suffix=suffix, + ) + ) + + async def openai_chat_completion( + self, + model: str, + messages: list[OpenAIMessageParam], + frequency_penalty: float | None = None, + function_call: str | dict[str, Any] | None = None, + functions: list[dict[str, Any]] | None = None, + logit_bias: dict[str, float] | None = None, + logprobs: bool | None = None, + max_completion_tokens: int | None = None, + max_tokens: int | None = None, + n: int | None = None, + parallel_tool_calls: bool | None = None, + presence_penalty: float | None = None, + response_format: OpenAIResponseFormatParam | None = None, + seed: int | None = None, + stop: str | list[str] | None = None, + stream: bool | None = None, + stream_options: dict[str, Any] | None = None, + temperature: float | None = None, + tool_choice: str | dict[str, Any] | None = None, + tools: list[dict[str, Any]] | None = None, + top_logprobs: int | None = None, + top_p: float | None = None, + user: str | None = None, + ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]: + """ + Direct OpenAI chat completion API call. 
+ """ + # Type ignore because return types are compatible + return await self.client.chat.completions.create( # type: ignore[no-any-return] + **await prepare_openai_completion_params( + model=await self._get_provider_model_id(model), + messages=messages, + frequency_penalty=frequency_penalty, + function_call=function_call, + functions=functions, + logit_bias=logit_bias, + logprobs=logprobs, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + n=n, + parallel_tool_calls=parallel_tool_calls, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, + top_logprobs=top_logprobs, + top_p=top_p, + user=user, + ) + ) + + async def openai_embeddings( + self, + model: str, + input: str | list[str], + encoding_format: str | None = "float", + dimensions: int | None = None, + user: str | None = None, + ) -> OpenAIEmbeddingsResponse: + """ + Direct OpenAI embeddings API call. + """ + # Call OpenAI embeddings API with properly typed parameters + response = await self.client.embeddings.create( + model=await self._get_provider_model_id(model), + input=input, + encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN, + dimensions=dimensions if dimensions is not None else NOT_GIVEN, + user=user if user is not None else NOT_GIVEN, + ) + + data = [] + for i, embedding_data in enumerate(response.data): + data.append( + OpenAIEmbeddingData( + embedding=embedding_data.embedding, + index=i, + ) + ) + + usage = OpenAIEmbeddingUsage( + prompt_tokens=response.usage.prompt_tokens, + total_tokens=response.usage.total_tokens, + ) + + return OpenAIEmbeddingsResponse( + data=data, + model=response.model, + usage=usage, + ) + + async def check_model_availability(self, model: str) -> bool: + """ + Check if a specific model is available from OpenAI. + + :param model: The model identifier to check. + :return: True if the model is available dynamically, False otherwise. + """ + try: + # Direct model lookup - returns model or raises NotFoundError + await self.client.models.retrieve(model) + return True + except openai.NotFoundError: + # Model doesn't exist - this is expected for unavailable models + pass + except Exception as e: + # All other errors (auth, rate limit, network, etc.) 
+ logger.warning(f"Failed to check model availability for {model}: {e}") + + return False diff --git a/tests/unit/providers/inference/test_inference_client_caching.py b/tests/unit/providers/inference/test_inference_client_caching.py index c9a931d47..ba36a3e3d 100644 --- a/tests/unit/providers/inference/test_inference_client_caching.py +++ b/tests/unit/providers/inference/test_inference_client_caching.py @@ -10,6 +10,8 @@ from unittest.mock import MagicMock from llama_stack.distribution.request_headers import request_provider_data_context from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter +from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig +from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter from llama_stack.providers.remote.inference.openai.config import OpenAIConfig from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter from llama_stack.providers.remote.inference.together.config import TogetherImplConfig @@ -50,7 +52,7 @@ def test_openai_provider_openai_client_caching(): with request_provider_data_context( {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})} ): - openai_client = inference_adapter._get_openai_client() + openai_client = inference_adapter.client assert openai_client.api_key == api_key @@ -71,3 +73,18 @@ def test_together_provider_openai_client_caching(): assert together_client.client.api_key == api_key openai_client = inference_adapter._get_openai_client() assert openai_client.api_key == api_key + + +def test_llama_compat_provider_openai_client_caching(): + """Ensure the LlamaCompat provider does not cache api keys across client requests""" + config = LlamaCompatConfig() + inference_adapter = LlamaCompatInferenceAdapter(config) + + inference_adapter.__provider_spec__ = MagicMock() + inference_adapter.__provider_spec__.provider_data_validator = ( + "llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator" + ) + + for api_key in ["test1", "test2"]: + with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"llama_api_key": api_key})}): + assert inference_adapter.client.api_key == api_key
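
A minimal usage sketch, assuming a hypothetical provider: the `ExampleConfig` and `ExampleInferenceAdapter` names below are illustrative and not part of this patch; only the two abstract methods described in `new_api_provider.md` above are implemented, and everything else comes from `OpenAIMixin`.

```python
# Hypothetical provider sketch built on the OpenAIMixin added by this patch.
# ExampleConfig / ExampleInferenceAdapter are illustrative names, not real classes.
from pydantic import BaseModel, SecretStr

from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class ExampleConfig(BaseModel):
    url: str = "https://example.invalid/v1"  # any OpenAI-compatible endpoint
    api_key: SecretStr | None = None


class ExampleInferenceAdapter(OpenAIMixin):
    def __init__(self, config: ExampleConfig) -> None:
        self.config = config
        # self.model_store is injected by the Llama Stack distribution at runtime
        # (see the OpenAIMixin docstring); it is not set up here.

    def get_api_key(self) -> str:
        # Credential used by the AsyncOpenAI client built in OpenAIMixin.client
        return self.config.api_key.get_secret_value() if self.config.api_key else "NO KEY"

    def get_base_url(self) -> str:
        # Base URL used by the AsyncOpenAI client built in OpenAIMixin.client
        return self.config.url

    # openai_completion(), openai_chat_completion(), openai_embeddings(), and
    # check_model_availability() are inherited from OpenAIMixin.
```

Because the mixin's `client` property constructs a fresh `AsyncOpenAI` instance from `get_api_key()` and `get_base_url()` on each access, per-request provider data is picked up automatically, which is what the client-caching tests above assert.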
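The inheritance-order notes repeated in the adapter docstrings follow from Python's method resolution order; the stand-in classes below (assumed names, not the real mixins) show why the OpenAI-style mixin must be listed first for its `check_model_availability()` to be the one that runs.

```python
# Toy stand-ins that mimic the two check_model_availability() behaviors described
# in the docstrings; they are not the real ModelRegistryHelper / OpenAIMixin classes.
class StaticRegistry:
    async def check_model_availability(self, model: str) -> bool:
        return False  # like ModelRegistryHelper: unknown models are reported unavailable


class RemoteLookup:
    async def check_model_availability(self, model: str) -> bool:
        return True  # like OpenAIMixin: pretend the endpoint's models.retrieve() succeeded


class WrongOrderAdapter(StaticRegistry, RemoteLookup):
    pass


class RightOrderAdapter(RemoteLookup, StaticRegistry):
    pass


# The MRO picks the first base class that defines the method.
print(WrongOrderAdapter.check_model_availability.__qualname__)  # StaticRegistry.check_model_availability
print(RightOrderAdapter.check_model_availability.__qualname__)  # RemoteLookup.check_model_availability
```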