chore: create OpenAIMixin for inference providers with an OpenAI-compat API that need to implement openai_* methods (#2835)
Some checks failed
Coverage Badge / unit-tests (push) Failing after 3s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 6s
Python Package Build Test / build (3.12) (push) Failing after 3s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 6s
Integration Tests / discover-tests (push) Successful in 7s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 6s
Python Package Build Test / build (3.13) (push) Failing after 2s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 5s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 8s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 9s
Unit Tests / unit-tests (3.12) (push) Failing after 8s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 11s
Test External Providers / test-external-providers (venv) (push) Failing after 8s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 9s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 17s
Unit Tests / unit-tests (3.13) (push) Failing after 12s
Update ReadTheDocs / update-readthedocs (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 16s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 18s
Integration Tests / test-matrix (push) Failing after 18s
Pre-commit / pre-commit (push) Successful in 1m14s

# What does this PR do?

add an `OpenAIMixin` for use by inference providers who remote endpoints
support an OpenAI compatible API.

use is demonstrated by refactoring
- OpenAIInferenceAdapter
- NVIDIAInferenceAdapter (adds embedding support)
- LlamaCompatInferenceAdapter

## Test Plan

existing unit and integration tests
This commit is contained in:
Matthew Farrellee 2025-07-23 06:49:40 -04:00 committed by GitHub
parent fc67ad408a
commit e1ed152779
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 402 additions and 387 deletions

View file

@ -14,6 +14,41 @@ Here are some example PRs to help you get started:
- [Nvidia Inference Implementation](https://github.com/meta-llama/llama-stack/pull/355) - [Nvidia Inference Implementation](https://github.com/meta-llama/llama-stack/pull/355)
- [Model context protocol Tool Runtime](https://github.com/meta-llama/llama-stack/pull/665) - [Model context protocol Tool Runtime](https://github.com/meta-llama/llama-stack/pull/665)
## Inference Provider Patterns
When implementing Inference providers for OpenAI-compatible APIs, Llama Stack provides several mixin classes to simplify development and ensure consistent behavior across providers.
### OpenAIMixin
The `OpenAIMixin` class provides direct OpenAI API functionality for providers that work with OpenAI-compatible endpoints. It includes:
#### Direct API Methods
- **`openai_completion()`**: Legacy text completion API with full parameter support
- **`openai_chat_completion()`**: Chat completion API supporting streaming, tools, and function calling
- **`openai_embeddings()`**: Text embeddings generation with customizable encoding and dimensions
#### Model Management
- **`check_model_availability()`**: Queries the API endpoint to verify if a model exists and is accessible
#### Client Management
- **`client` property**: Automatically creates and configures AsyncOpenAI client instances using your provider's credentials
#### Required Implementation
To use `OpenAIMixin`, your provider must implement these abstract methods:
```python
@abstractmethod
def get_api_key(self) -> str:
"""Return the API key for authentication"""
pass
@abstractmethod
def get_base_url(self) -> str:
"""Return the OpenAI-compatible API base URL"""
pass
```
## Testing the Provider ## Testing the Provider

View file

@ -5,17 +5,27 @@
# the root directory of this source tree. # the root directory of this source tree.
import logging import logging
from llama_api_client import AsyncLlamaAPIClient, NotFoundError
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .models import MODEL_ENTRIES from .models import MODEL_ENTRIES
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin): class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
"""
Llama API Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability().
- OpenAIMixin.check_model_availability() queries the Llama API to check if a model exists
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
"""
_config: LlamaCompatConfig _config: LlamaCompatConfig
def __init__(self, config: LlamaCompatConfig): def __init__(self, config: LlamaCompatConfig):
@ -28,32 +38,19 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
) )
self.config = config self.config = config
async def check_model_availability(self, model: str) -> bool: # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
""" """
Check if a specific model is available from Llama API. Get the base URL for OpenAI mixin.
:param model: The model identifier to check. :return: The Llama API base URL
:return: True if the model is available dynamically, False otherwise.
""" """
try: return self.config.openai_compat_api_base
llama_api_client = self._get_llama_api_client()
retrieved_model = await llama_api_client.models.retrieve(model)
logger.info(f"Model {retrieved_model.id} is available from Llama API")
return True
except NotFoundError:
logger.error(f"Model {model} is not available from Llama API")
return False
except Exception as e:
logger.error(f"Failed to check model availability from Llama API: {e}")
return False
async def initialize(self): async def initialize(self):
await super().initialize() await super().initialize()
async def shutdown(self): async def shutdown(self):
await super().shutdown() await super().shutdown()
def _get_llama_api_client(self) -> AsyncLlamaAPIClient:
return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base)

View file

@ -7,9 +7,8 @@
import logging import logging
import warnings import warnings
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import Any
from openai import APIConnectionError, AsyncOpenAI, BadRequestError, NotFoundError from openai import APIConnectionError, BadRequestError
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
InterleavedContent, InterleavedContent,
@ -28,12 +27,6 @@ from llama_stack.apis.inference import (
Inference, Inference,
LogProbConfig, LogProbConfig,
Message, Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat, ResponseFormat,
SamplingParams, SamplingParams,
TextTruncation, TextTruncation,
@ -47,8 +40,8 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
convert_openai_chat_completion_choice, convert_openai_chat_completion_choice,
convert_openai_chat_completion_stream, convert_openai_chat_completion_stream,
prepare_openai_completion_params,
) )
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import content_has_media from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
from . import NVIDIAConfig from . import NVIDIAConfig
@ -64,7 +57,20 @@ from .utils import _is_nvidia_hosted
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
"""
NVIDIA Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
ModelRegistryHelper to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability(). It also
must come before Inference to ensure that OpenAIMixin methods are available
in the Inference interface.
- OpenAIMixin.check_model_availability() queries the NVIDIA API to check if a model exists
- ModelRegistryHelper.check_model_availability() just returns False and shows a warning
"""
def __init__(self, config: NVIDIAConfig) -> None: def __init__(self, config: NVIDIAConfig) -> None:
# TODO(mf): filter by available models # TODO(mf): filter by available models
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
@ -88,45 +94,21 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
self._config = config self._config = config
async def check_model_availability(self, model: str) -> bool: def get_api_key(self) -> str:
""" """
Check if a specific model is available. Get the API key for OpenAI mixin.
:param model: The model identifier to check. :return: The NVIDIA API key
:return: True if the model is available dynamically, False otherwise.
""" """
try: return self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"
await self._client.models.retrieve(model)
return True
except NotFoundError:
logger.error(f"Model {model} is not available")
except Exception as e:
logger.error(f"Failed to check model availability: {e}")
return False
@property def get_base_url(self) -> str:
def _client(self) -> AsyncOpenAI:
""" """
Returns an OpenAI client for the configured NVIDIA API endpoint. Get the base URL for OpenAI mixin.
:return: An OpenAI client :return: The NVIDIA API base URL
""" """
return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
return AsyncOpenAI(
base_url=base_url,
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
timeout=self._config.timeout,
)
async def _get_provider_model_id(self, model_id: str) -> str:
if not self.model_store:
raise RuntimeError("Model store is not set")
model = await self.model_store.get_model(model_id)
if model is None:
raise ValueError(f"Model {model_id} is unknown")
return model.provider_model_id
async def completion( async def completion(
self, self,
@ -160,7 +142,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
) )
try: try:
response = await self._client.completions.create(**request) response = await self.client.completions.create(**request)
except APIConnectionError as e: except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -213,7 +195,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
extra_body["input_type"] = task_type_options[task_type] extra_body["input_type"] = task_type_options[task_type]
try: try:
response = await self._client.embeddings.create( response = await self.client.embeddings.create(
model=provider_model_id, model=provider_model_id,
input=input, input=input,
extra_body=extra_body, extra_body=extra_body,
@ -228,16 +210,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
# #
return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data]) return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def chat_completion( async def chat_completion(
self, self,
model_id: str, model_id: str,
@ -274,7 +246,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
) )
try: try:
response = await self._client.chat.completions.create(**request) response = await self.client.chat.completions.create(**request)
except APIConnectionError as e: except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -283,112 +255,3 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
else: else:
# we pass n=1 to get only one completion # we pass n=1 to get only one completion
return convert_openai_chat_completion_choice(response.choices[0]) return convert_openai_chat_completion_choice(response.choices[0])
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
provider_model_id = await self._get_provider_model_id(model)
params = await prepare_openai_completion_params(
model=provider_model_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
try:
return await self._client.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
provider_model_id = await self._get_provider_model_id(model)
params = await prepare_openai_completion_params(
model=provider_model_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
try:
return await self._client.chat.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e

View file

@ -5,23 +5,9 @@
# the root directory of this source tree. # the root directory of this source tree.
import logging import logging
from collections.abc import AsyncIterator
from typing import Any
from openai import AsyncOpenAI, NotFoundError
from llama_stack.apis.inference import (
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig from .config import OpenAIConfig
from .models import MODEL_ENTRIES from .models import MODEL_ENTRIES
@ -30,7 +16,7 @@ logger = logging.getLogger(__name__)
# #
# This OpenAI adapter implements Inference methods using two clients - # This OpenAI adapter implements Inference methods using two mixins -
# #
# | Inference Method | Implementation Source | # | Inference Method | Implementation Source |
# |----------------------------|--------------------------| # |----------------------------|--------------------------|
@ -39,11 +25,22 @@ logger = logging.getLogger(__name__)
# | embedding | LiteLLMOpenAIMixin | # | embedding | LiteLLMOpenAIMixin |
# | batch_completion | LiteLLMOpenAIMixin | # | batch_completion | LiteLLMOpenAIMixin |
# | batch_chat_completion | LiteLLMOpenAIMixin | # | batch_chat_completion | LiteLLMOpenAIMixin |
# | openai_completion | AsyncOpenAI | # | openai_completion | OpenAIMixin |
# | openai_chat_completion | AsyncOpenAI | # | openai_chat_completion | OpenAIMixin |
# | openai_embeddings | AsyncOpenAI | # | openai_embeddings | OpenAIMixin |
# #
class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
"""
OpenAI Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of ModelRegistryHelper.check_model_availability().
- OpenAIMixin.check_model_availability() queries the OpenAI API to check if a model exists
- ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
"""
def __init__(self, config: OpenAIConfig) -> None: def __init__(self, config: OpenAIConfig) -> None:
LiteLLMOpenAIMixin.__init__( LiteLLMOpenAIMixin.__init__(
self, self,
@ -60,191 +57,19 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
# litellm specific model names, an abstraction leak. # litellm specific model names, an abstraction leak.
self.is_openai_compat = True self.is_openai_compat = True
async def check_model_availability(self, model: str) -> bool: # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
""" """
Check if a specific model is available from OpenAI. Get the OpenAI API base URL.
:param model: The model identifier to check. Returns the standard OpenAI API base URL for direct OpenAI API calls.
:return: True if the model is available dynamically, False otherwise.
""" """
try: return "https://api.openai.com/v1"
openai_client = self._get_openai_client()
retrieved_model = await openai_client.models.retrieve(model)
logger.info(f"Model {retrieved_model.id} is available from OpenAI")
return True
except NotFoundError:
logger.error(f"Model {model} is not available from OpenAI")
return False
except Exception as e:
logger.error(f"Failed to check model availability from OpenAI: {e}")
return False
async def initialize(self) -> None: async def initialize(self) -> None:
await super().initialize() await super().initialize()
async def shutdown(self) -> None: async def shutdown(self) -> None:
await super().shutdown() await super().shutdown()
def _get_openai_client(self) -> AsyncOpenAI:
return AsyncOpenAI(
api_key=self.get_api_key(),
)
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
if guided_choice is not None:
logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
if prompt_logprobs is not None:
logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
params = await prepare_openai_completion_params(
model=model_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
suffix=suffix,
)
return await self._get_openai_client().completions.create(**params)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
params = await prepare_openai_completion_params(
model=model_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
return await self._get_openai_client().chat.completions.create(**params)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
model_id = (await self.model_store.get_model(model)).provider_resource_id
if model_id.startswith("openai/"):
model_id = model_id[len("openai/") :]
# Prepare parameters for OpenAI embeddings API
params = {
"model": model_id,
"input": input,
}
if encoding_format is not None:
params["encoding_format"] = encoding_format
if dimensions is not None:
params["dimensions"] = dimensions
if user is not None:
params["user"] = user
# Call OpenAI embeddings API
response = await self._get_openai_client().embeddings.create(**params)
data = []
for i, embedding_data in enumerate(response.data):
data.append(
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
return OpenAIEmbeddingsResponse(
data=data,
model=response.model,
usage=usage,
)

View file

@ -10,12 +10,15 @@ from pydantic import BaseModel, Field
from llama_stack.apis.common.errors import UnsupportedModelError from llama_stack.apis.common.errors import UnsupportedModelError
from llama_stack.apis.models import ModelType from llama_stack.apis.models import ModelType
from llama_stack.log import get_logger
from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.models.llama.sku_list import all_registered_models
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference import ( from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR, ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
) )
logger = get_logger(name=__name__, category="core")
# TODO: this class is more confusing than useful right now. We need to make it # TODO: this class is more confusing than useful right now. We need to make it
# more closer to the Model class. # more closer to the Model class.
@ -98,6 +101,9 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
:param model: The model identifier to check. :param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise. :return: True if the model is available dynamically, False otherwise.
""" """
logger.info(
f"check_model_availability is not implemented for {self.__class__.__name__}. Returning False by default."
)
return False return False
async def register_model(self, model: Model) -> Model: async def register_model(self, model: Model) -> Model:

View file

@ -0,0 +1,272 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any
import openai
from openai import NOT_GIVEN, AsyncOpenAI
from llama_stack.apis.inference import (
Model,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingData,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
logger = get_logger(name=__name__, category="core")
class OpenAIMixin(ABC):
"""
Mixin class that provides OpenAI-specific functionality for inference providers.
This class handles direct OpenAI API calls using the AsyncOpenAI client.
This is an abstract base class that requires child classes to implement:
- get_api_key(): Method to retrieve the API key
- get_base_url(): Method to retrieve the OpenAI-compatible API base URL
Expected Dependencies:
- self.model_store: Injected by the Llama Stack distribution system at runtime.
This provides model registry functionality for looking up registered models.
The model_store is set in routing_tables/common.py during provider initialization.
"""
@abstractmethod
def get_api_key(self) -> str:
"""
Get the API key.
This method must be implemented by child classes to provide the API key
for authenticating with the OpenAI API or compatible endpoints.
:return: The API key as a string
"""
pass
@abstractmethod
def get_base_url(self) -> str:
"""
Get the OpenAI-compatible API base URL.
This method must be implemented by child classes to provide the base URL
for the OpenAI API or compatible endpoints (e.g., "https://api.openai.com/v1").
:return: The base URL as a string
"""
pass
@property
def client(self) -> AsyncOpenAI:
"""
Get an AsyncOpenAI client instance.
Uses the abstract methods get_api_key() and get_base_url() which must be
implemented by child classes.
"""
return AsyncOpenAI(
api_key=self.get_api_key(),
base_url=self.get_base_url(),
)
async def _get_provider_model_id(self, model: str) -> str:
"""
Get the provider-specific model ID from the model store.
This is a utility method that looks up the registered model and returns
the provider_resource_id that should be used for actual API calls.
:param model: The registered model name/identifier
:return: The provider-specific model ID (e.g., "gpt-4")
"""
# Look up the registered model to get the provider-specific model ID
# self.model_store is injected by the distribution system at runtime
model_obj: Model = await self.model_store.get_model(model) # type: ignore[attr-defined]
# provider_resource_id is str | None, but we expect it to be str for OpenAI calls
if model_obj.provider_resource_id is None:
raise ValueError(f"Model {model} has no provider_resource_id")
return model_obj.provider_resource_id
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
"""
Direct OpenAI completion API call.
"""
if guided_choice is not None:
logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
if prompt_logprobs is not None:
logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
# TODO: fix openai_completion to return type compatible with OpenAI's API response
return await self.client.completions.create( # type: ignore[no-any-return]
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
suffix=suffix,
)
)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
"""
Direct OpenAI chat completion API call.
"""
# Type ignore because return types are compatible
return await self.client.chat.completions.create( # type: ignore[no-any-return]
**await prepare_openai_completion_params(
model=await self._get_provider_model_id(model),
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
"""
Direct OpenAI embeddings API call.
"""
# Call OpenAI embeddings API with properly typed parameters
response = await self.client.embeddings.create(
model=await self._get_provider_model_id(model),
input=input,
encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
dimensions=dimensions if dimensions is not None else NOT_GIVEN,
user=user if user is not None else NOT_GIVEN,
)
data = []
for i, embedding_data in enumerate(response.data):
data.append(
OpenAIEmbeddingData(
embedding=embedding_data.embedding,
index=i,
)
)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
return OpenAIEmbeddingsResponse(
data=data,
model=response.model,
usage=usage,
)
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from OpenAI.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
try:
# Direct model lookup - returns model or raises NotFoundError
await self.client.models.retrieve(model)
return True
except openai.NotFoundError:
# Model doesn't exist - this is expected for unavailable models
pass
except Exception as e:
# All other errors (auth, rate limit, network, etc.)
logger.warning(f"Failed to check model availability for {model}: {e}")
return False

View file

@ -10,6 +10,8 @@ from unittest.mock import MagicMock
from llama_stack.distribution.request_headers import request_provider_data_context from llama_stack.distribution.request_headers import request_provider_data_context
from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
@ -50,7 +52,7 @@ def test_openai_provider_openai_client_caching():
with request_provider_data_context( with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})} {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
): ):
openai_client = inference_adapter._get_openai_client() openai_client = inference_adapter.client
assert openai_client.api_key == api_key assert openai_client.api_key == api_key
@ -71,3 +73,18 @@ def test_together_provider_openai_client_caching():
assert together_client.client.api_key == api_key assert together_client.client.api_key == api_key
openai_client = inference_adapter._get_openai_client() openai_client = inference_adapter._get_openai_client()
assert openai_client.api_key == api_key assert openai_client.api_key == api_key
def test_llama_compat_provider_openai_client_caching():
"""Ensure the LlamaCompat provider does not cache api keys across client requests"""
config = LlamaCompatConfig()
inference_adapter = LlamaCompatInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"llama_api_key": api_key})}):
assert inference_adapter.client.api_key == api_key