mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-25 21:57:45 +00:00
chore: create OpenAIMixin for inference providers with an OpenAI-compat API that need to implement openai_* methods (#2835)
Some checks failed
Coverage Badge / unit-tests (push) Failing after 3s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 6s
Python Package Build Test / build (3.12) (push) Failing after 3s
Vector IO Integration Tests / test-matrix (3.12, inline::milvus) (push) Failing after 6s
Integration Tests / discover-tests (push) Successful in 7s
Vector IO Integration Tests / test-matrix (3.12, remote::chromadb) (push) Failing after 6s
Python Package Build Test / build (3.13) (push) Failing after 2s
Vector IO Integration Tests / test-matrix (3.13, inline::milvus) (push) Failing after 6s
Vector IO Integration Tests / test-matrix (3.13, inline::sqlite-vec) (push) Failing after 5s
Vector IO Integration Tests / test-matrix (3.12, inline::faiss) (push) Failing after 8s
Vector IO Integration Tests / test-matrix (3.13, inline::faiss) (push) Failing after 9s
Unit Tests / unit-tests (3.12) (push) Failing after 8s
Vector IO Integration Tests / test-matrix (3.12, inline::sqlite-vec) (push) Failing after 11s
Test External Providers / test-external-providers (venv) (push) Failing after 8s
Vector IO Integration Tests / test-matrix (3.13, remote::pgvector) (push) Failing after 12s
Vector IO Integration Tests / test-matrix (3.13, remote::chromadb) (push) Failing after 9s
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 17s
Unit Tests / unit-tests (3.13) (push) Failing after 12s
Update ReadTheDocs / update-readthedocs (push) Failing after 11s
Vector IO Integration Tests / test-matrix (3.12, remote::pgvector) (push) Failing after 16s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 18s
Integration Tests / test-matrix (push) Failing after 18s
Pre-commit / pre-commit (push) Successful in 1m14s
# What does this PR do?

Add an `OpenAIMixin` for use by inference providers whose remote endpoints support an OpenAI-compatible API. Its use is demonstrated by refactoring:
- OpenAIInferenceAdapter
- NVIDIAInferenceAdapter (adds embedding support)
- LlamaCompatInferenceAdapter

## Test Plan

Existing unit and integration tests.
This commit is contained in:
parent
fc67ad408a
commit
e1ed152779
7 changed files with 402 additions and 387 deletions
Documentation (guide for adding a new API provider):

@@ -14,6 +14,41 @@ Here are some example PRs to help you get started:
 - [Nvidia Inference Implementation](https://github.com/meta-llama/llama-stack/pull/355)
 - [Model context protocol Tool Runtime](https://github.com/meta-llama/llama-stack/pull/665)
 
+## Inference Provider Patterns
+
+When implementing Inference providers for OpenAI-compatible APIs, Llama Stack provides several mixin classes to simplify development and ensure consistent behavior across providers.
+
+### OpenAIMixin
+
+The `OpenAIMixin` class provides direct OpenAI API functionality for providers that work with OpenAI-compatible endpoints. It includes:
+
+#### Direct API Methods
+
+- **`openai_completion()`**: Legacy text completion API with full parameter support
+- **`openai_chat_completion()`**: Chat completion API supporting streaming, tools, and function calling
+- **`openai_embeddings()`**: Text embeddings generation with customizable encoding and dimensions
+
+#### Model Management
+
+- **`check_model_availability()`**: Queries the API endpoint to verify if a model exists and is accessible
+
+#### Client Management
+
+- **`client` property**: Automatically creates and configures AsyncOpenAI client instances using your provider's credentials
+
+#### Required Implementation
+
+To use `OpenAIMixin`, your provider must implement these abstract methods:
+
+```python
+@abstractmethod
+def get_api_key(self) -> str:
+    """Return the API key for authentication"""
+    pass
+
+@abstractmethod
+def get_base_url(self) -> str:
+    """Return the OpenAI-compatible API base URL"""
+    pass
+```
+
 ## Testing the Provider
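For concreteness, a minimal sketch of a provider built on `OpenAIMixin` might look like the following; the `AcmeInferenceAdapter` name, its constructor, and the endpoint URL are hypothetical, not part of this PR:

```python
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class AcmeInferenceAdapter(OpenAIMixin):
    """Hypothetical adapter for an OpenAI-compatible endpoint."""

    def __init__(self, api_key: str, base_url: str) -> None:
        self._api_key = api_key
        self._base_url = base_url

    def get_api_key(self) -> str:
        # Credential handed to the AsyncOpenAI client built by the mixin's `client` property
        return self._api_key

    def get_base_url(self) -> str:
        # OpenAI-compatible endpoint, e.g. "https://api.acme.example/v1"
        return self._base_url
```

With just these two methods implemented, the adapter inherits working `openai_completion()`, `openai_chat_completion()`, `openai_embeddings()`, and `check_model_availability()` from the mixin; the `openai_*` methods additionally rely on the `model_store` that the Llama Stack distribution injects at runtime.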
llama_stack/providers/remote/inference/llama_openai_compat/llama.py:

@@ -5,17 +5,27 @@
 # the root directory of this source tree.
 import logging
 
-from llama_api_client import AsyncLlamaAPIClient, NotFoundError
-
 from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .models import MODEL_ENTRIES
 
 logger = logging.getLogger(__name__)
 
 
-class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
+class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
+    """
+    Llama API Inference Adapter for Llama Stack.
+
+    Note: The inheritance order is important here. OpenAIMixin must come before
+    LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
+    is used instead of ModelRegistryHelper.check_model_availability().
+
+    - OpenAIMixin.check_model_availability() queries the Llama API to check if a model exists
+    - ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
+    """
+
     _config: LlamaCompatConfig
 
     def __init__(self, config: LlamaCompatConfig):
@@ -28,32 +38,19 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
         )
         self.config = config
 
-    async def check_model_availability(self, model: str) -> bool:
-        """
-        Check if a specific model is available from Llama API.
-
-        :param model: The model identifier to check.
-        :return: True if the model is available dynamically, False otherwise.
-        """
-        try:
-            llama_api_client = self._get_llama_api_client()
-            retrieved_model = await llama_api_client.models.retrieve(model)
-            logger.info(f"Model {retrieved_model.id} is available from Llama API")
-            return True
-
-        except NotFoundError:
-            logger.error(f"Model {model} is not available from Llama API")
-            return False
-
-        except Exception as e:
-            logger.error(f"Failed to check model availability from Llama API: {e}")
-            return False
+    # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
+    get_api_key = LiteLLMOpenAIMixin.get_api_key
+
+    def get_base_url(self) -> str:
+        """
+        Get the base URL for OpenAI mixin.
+
+        :return: The Llama API base URL
+        """
+        return self.config.openai_compat_api_base
 
     async def initialize(self):
         await super().initialize()
 
     async def shutdown(self):
         await super().shutdown()
-
-    def _get_llama_api_client(self) -> AsyncLlamaAPIClient:
-        return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base)
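The inheritance-order note in the docstring can be sanity-checked with Python's MRO; this is an illustrative snippet, not part of the PR:

```python
from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

# Because OpenAIMixin is listed first, it precedes LiteLLMOpenAIMixin (and therefore
# ModelRegistryHelper) in the MRO, so its check_model_availability() is the one resolved.
mro = LlamaCompatInferenceAdapter.__mro__
assert mro.index(OpenAIMixin) < mro.index(LiteLLMOpenAIMixin)
assert LlamaCompatInferenceAdapter.check_model_availability is OpenAIMixin.check_model_availability
```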
NVIDIA inference adapter:

@@ -7,9 +7,8 @@
 import logging
 import warnings
 from collections.abc import AsyncIterator
-from typing import Any
 
-from openai import APIConnectionError, AsyncOpenAI, BadRequestError, NotFoundError
+from openai import APIConnectionError, BadRequestError
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -28,12 +27,6 @@ from llama_stack.apis.inference import (
     Inference,
     LogProbConfig,
     Message,
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-    OpenAIEmbeddingsResponse,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
     ResponseFormat,
     SamplingParams,
     TextTruncation,
@@ -47,8 +40,8 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_openai_chat_completion_choice,
     convert_openai_chat_completion_stream,
-    prepare_openai_completion_params,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
 
 from . import NVIDIAConfig
@@ -64,7 +57,20 @@ from .utils import _is_nvidia_hosted
 logger = logging.getLogger(__name__)
 
 
-class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
+class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
+    """
+    NVIDIA Inference Adapter for Llama Stack.
+
+    Note: The inheritance order is important here. OpenAIMixin must come before
+    ModelRegistryHelper to ensure that OpenAIMixin.check_model_availability()
+    is used instead of ModelRegistryHelper.check_model_availability(). It also
+    must come before Inference to ensure that OpenAIMixin methods are available
+    in the Inference interface.
+
+    - OpenAIMixin.check_model_availability() queries the NVIDIA API to check if a model exists
+    - ModelRegistryHelper.check_model_availability() just returns False and shows a warning
+    """
+
     def __init__(self, config: NVIDIAConfig) -> None:
         # TODO(mf): filter by available models
         ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
@@ -88,45 +94,21 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
 
         self._config = config
 
-    async def check_model_availability(self, model: str) -> bool:
-        """
-        Check if a specific model is available.
-
-        :param model: The model identifier to check.
-        :return: True if the model is available dynamically, False otherwise.
-        """
-        try:
-            await self._client.models.retrieve(model)
-            return True
-        except NotFoundError:
-            logger.error(f"Model {model} is not available")
-        except Exception as e:
-            logger.error(f"Failed to check model availability: {e}")
-        return False
-
-    @property
-    def _client(self) -> AsyncOpenAI:
-        """
-        Returns an OpenAI client for the configured NVIDIA API endpoint.
-
-        :return: An OpenAI client
-        """
-        base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
-
-        return AsyncOpenAI(
-            base_url=base_url,
-            api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
-            timeout=self._config.timeout,
-        )
-
-    async def _get_provider_model_id(self, model_id: str) -> str:
-        if not self.model_store:
-            raise RuntimeError("Model store is not set")
-        model = await self.model_store.get_model(model_id)
-        if model is None:
-            raise ValueError(f"Model {model_id} is unknown")
-        return model.provider_model_id
+    def get_api_key(self) -> str:
+        """
+        Get the API key for OpenAI mixin.
+
+        :return: The NVIDIA API key
+        """
+        return self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"
+
+    def get_base_url(self) -> str:
+        """
+        Get the base URL for OpenAI mixin.
+
+        :return: The NVIDIA API base URL
+        """
+        return f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
 
     async def completion(
         self,
@@ -160,7 +142,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )
 
         try:
-            response = await self._client.completions.create(**request)
+            response = await self.client.completions.create(**request)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
 
@@ -213,7 +195,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             extra_body["input_type"] = task_type_options[task_type]
 
         try:
-            response = await self._client.embeddings.create(
+            response = await self.client.embeddings.create(
                 model=provider_model_id,
                 input=input,
                 extra_body=extra_body,
@@ -228,16 +210,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         #
         return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data])
 
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        raise NotImplementedError()
-
     async def chat_completion(
         self,
         model_id: str,
@@ -274,7 +246,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         )
 
         try:
-            response = await self._client.chat.completions.create(**request)
+            response = await self.client.chat.completions.create(**request)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
 
@@ -283,112 +255,3 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
         else:
             # we pass n=1 to get only one completion
             return convert_openai_chat_completion_choice(response.choices[0])
-
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        provider_model_id = await self._get_provider_model_id(model)
-
-        params = await prepare_openai_completion_params(
-            model=provider_model_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-        )
-
-        try:
-            return await self._client.completions.create(**params)
-        except APIConnectionError as e:
-            raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
-
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        provider_model_id = await self._get_provider_model_id(model)
-
-        params = await prepare_openai_completion_params(
-            model=provider_model_id,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-
-        try:
-            return await self._client.chat.completions.create(**params)
-        except APIConnectionError as e:
-            raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
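A possible unit-test sketch for the new `get_base_url()`; the module paths, the `NVIDIAConfig` construction, and the endpoint URL are assumptions on my part rather than part of this PR:

```python
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter


def test_nvidia_get_base_url_respects_append_api_version():
    # With append_api_version=True the adapter appends the "/v1" suffix expected
    # by OpenAI-compatible NIM endpoints; with False the configured URL is used as-is.
    adapter = NVIDIAInferenceAdapter(NVIDIAConfig(url="https://nim.example.com", append_api_version=True))
    assert adapter.get_base_url() == "https://nim.example.com/v1"

    adapter = NVIDIAInferenceAdapter(NVIDIAConfig(url="https://nim.example.com", append_api_version=False))
    assert adapter.get_base_url() == "https://nim.example.com"
```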
llama_stack/providers/remote/inference/openai/openai.py:

@@ -5,23 +5,9 @@
 # the root directory of this source tree.
 
 import logging
-from collections.abc import AsyncIterator
-from typing import Any
-
-from openai import AsyncOpenAI, NotFoundError
-
-from llama_stack.apis.inference import (
-    OpenAIChatCompletion,
-    OpenAIChatCompletionChunk,
-    OpenAICompletion,
-    OpenAIEmbeddingData,
-    OpenAIEmbeddingsResponse,
-    OpenAIEmbeddingUsage,
-    OpenAIMessageParam,
-    OpenAIResponseFormatParam,
-)
 
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
-from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .config import OpenAIConfig
 from .models import MODEL_ENTRIES
@@ -30,7 +16,7 @@ logger = logging.getLogger(__name__)
 
 
 #
-# This OpenAI adapter implements Inference methods using two clients -
+# This OpenAI adapter implements Inference methods using two mixins -
 #
 # | Inference Method | Implementation Source |
 # |----------------------------|--------------------------|
@@ -39,11 +25,22 @@ logger = logging.getLogger(__name__)
 # | embedding | LiteLLMOpenAIMixin |
 # | batch_completion | LiteLLMOpenAIMixin |
 # | batch_chat_completion | LiteLLMOpenAIMixin |
-# | openai_completion | AsyncOpenAI |
-# | openai_chat_completion | AsyncOpenAI |
-# | openai_embeddings | AsyncOpenAI |
+# | openai_completion | OpenAIMixin |
+# | openai_chat_completion | OpenAIMixin |
+# | openai_embeddings | OpenAIMixin |
 #
-class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
+class OpenAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
+    """
+    OpenAI Inference Adapter for Llama Stack.
+
+    Note: The inheritance order is important here. OpenAIMixin must come before
+    LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
+    is used instead of ModelRegistryHelper.check_model_availability().
+
+    - OpenAIMixin.check_model_availability() queries the OpenAI API to check if a model exists
+    - ModelRegistryHelper.check_model_availability() (inherited by LiteLLMOpenAIMixin) just returns False and shows a warning
+    """
+
     def __init__(self, config: OpenAIConfig) -> None:
         LiteLLMOpenAIMixin.__init__(
             self,
@@ -60,191 +57,19 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         # litellm specific model names, an abstraction leak.
         self.is_openai_compat = True
 
-    async def check_model_availability(self, model: str) -> bool:
-        """
-        Check if a specific model is available from OpenAI.
-
-        :param model: The model identifier to check.
-        :return: True if the model is available dynamically, False otherwise.
-        """
-        try:
-            openai_client = self._get_openai_client()
-            retrieved_model = await openai_client.models.retrieve(model)
-            logger.info(f"Model {retrieved_model.id} is available from OpenAI")
-            return True
-
-        except NotFoundError:
-            logger.error(f"Model {model} is not available from OpenAI")
-            return False
-
-        except Exception as e:
-            logger.error(f"Failed to check model availability from OpenAI: {e}")
-            return False
+    # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
+    get_api_key = LiteLLMOpenAIMixin.get_api_key
+
+    def get_base_url(self) -> str:
+        """
+        Get the OpenAI API base URL.
+
+        Returns the standard OpenAI API base URL for direct OpenAI API calls.
+        """
+        return "https://api.openai.com/v1"
 
     async def initialize(self) -> None:
         await super().initialize()
 
     async def shutdown(self) -> None:
         await super().shutdown()
-
-    def _get_openai_client(self) -> AsyncOpenAI:
-        return AsyncOpenAI(
-            api_key=self.get_api_key(),
-        )
-
-    async def openai_completion(
-        self,
-        model: str,
-        prompt: str | list[str] | list[int] | list[list[int]],
-        best_of: int | None = None,
-        echo: bool | None = None,
-        frequency_penalty: float | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        presence_penalty: float | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-        guided_choice: list[str] | None = None,
-        prompt_logprobs: int | None = None,
-        suffix: str | None = None,
-    ) -> OpenAICompletion:
-        if guided_choice is not None:
-            logging.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
-        if prompt_logprobs is not None:
-            logging.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")
-
-        model_id = (await self.model_store.get_model(model)).provider_resource_id
-        if model_id.startswith("openai/"):
-            model_id = model_id[len("openai/") :]
-        params = await prepare_openai_completion_params(
-            model=model_id,
-            prompt=prompt,
-            best_of=best_of,
-            echo=echo,
-            frequency_penalty=frequency_penalty,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_tokens=max_tokens,
-            n=n,
-            presence_penalty=presence_penalty,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            top_p=top_p,
-            user=user,
-            suffix=suffix,
-        )
-        return await self._get_openai_client().completions.create(**params)
-
-    async def openai_chat_completion(
-        self,
-        model: str,
-        messages: list[OpenAIMessageParam],
-        frequency_penalty: float | None = None,
-        function_call: str | dict[str, Any] | None = None,
-        functions: list[dict[str, Any]] | None = None,
-        logit_bias: dict[str, float] | None = None,
-        logprobs: bool | None = None,
-        max_completion_tokens: int | None = None,
-        max_tokens: int | None = None,
-        n: int | None = None,
-        parallel_tool_calls: bool | None = None,
-        presence_penalty: float | None = None,
-        response_format: OpenAIResponseFormatParam | None = None,
-        seed: int | None = None,
-        stop: str | list[str] | None = None,
-        stream: bool | None = None,
-        stream_options: dict[str, Any] | None = None,
-        temperature: float | None = None,
-        tool_choice: str | dict[str, Any] | None = None,
-        tools: list[dict[str, Any]] | None = None,
-        top_logprobs: int | None = None,
-        top_p: float | None = None,
-        user: str | None = None,
-    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
-        model_id = (await self.model_store.get_model(model)).provider_resource_id
-        if model_id.startswith("openai/"):
-            model_id = model_id[len("openai/") :]
-        params = await prepare_openai_completion_params(
-            model=model_id,
-            messages=messages,
-            frequency_penalty=frequency_penalty,
-            function_call=function_call,
-            functions=functions,
-            logit_bias=logit_bias,
-            logprobs=logprobs,
-            max_completion_tokens=max_completion_tokens,
-            max_tokens=max_tokens,
-            n=n,
-            parallel_tool_calls=parallel_tool_calls,
-            presence_penalty=presence_penalty,
-            response_format=response_format,
-            seed=seed,
-            stop=stop,
-            stream=stream,
-            stream_options=stream_options,
-            temperature=temperature,
-            tool_choice=tool_choice,
-            tools=tools,
-            top_logprobs=top_logprobs,
-            top_p=top_p,
-            user=user,
-        )
-        return await self._get_openai_client().chat.completions.create(**params)
-
-    async def openai_embeddings(
-        self,
-        model: str,
-        input: str | list[str],
-        encoding_format: str | None = "float",
-        dimensions: int | None = None,
-        user: str | None = None,
-    ) -> OpenAIEmbeddingsResponse:
-        model_id = (await self.model_store.get_model(model)).provider_resource_id
-        if model_id.startswith("openai/"):
-            model_id = model_id[len("openai/") :]
-
-        # Prepare parameters for OpenAI embeddings API
-        params = {
-            "model": model_id,
-            "input": input,
-        }
-
-        if encoding_format is not None:
-            params["encoding_format"] = encoding_format
-        if dimensions is not None:
-            params["dimensions"] = dimensions
-        if user is not None:
-            params["user"] = user
-
-        # Call OpenAI embeddings API
-        response = await self._get_openai_client().embeddings.create(**params)
-
-        data = []
-        for i, embedding_data in enumerate(response.data):
-            data.append(
-                OpenAIEmbeddingData(
-                    embedding=embedding_data.embedding,
-                    index=i,
-                )
-            )
-
-        usage = OpenAIEmbeddingUsage(
-            prompt_tokens=response.usage.prompt_tokens,
-            total_tokens=response.usage.total_tokens,
-        )
-
-        return OpenAIEmbeddingsResponse(
-            data=data,
-            model=response.model,
-            usage=usage,
-        )
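Since `get_base_url()` is now a fixed value and `get_api_key` is delegated to LiteLLMOpenAIMixin, a small test sketch could pin that wiring down; constructing `OpenAIConfig()` with defaults is my own assumption here:

```python
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter


def test_openai_adapter_base_url_is_fixed():
    # The adapter always targets the public OpenAI endpoint; per-request API keys
    # are resolved by the delegated LiteLLMOpenAIMixin.get_api_key when the mixin
    # builds its AsyncOpenAI client.
    adapter = OpenAIInferenceAdapter(OpenAIConfig())
    assert adapter.get_base_url() == "https://api.openai.com/v1"
```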
llama_stack/providers/utils/inference/model_registry.py:

@@ -10,12 +10,15 @@ from pydantic import BaseModel, Field
 
 from llama_stack.apis.common.errors import UnsupportedModelError
 from llama_stack.apis.models import ModelType
+from llama_stack.log import get_logger
 from llama_stack.models.llama.sku_list import all_registered_models
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference import (
     ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
 )
 
+logger = get_logger(name=__name__, category="core")
+
 
 # TODO: this class is more confusing than useful right now. We need to make it
 # more closer to the Model class.
@@ -98,6 +101,9 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         :param model: The model identifier to check.
         :return: True if the model is available dynamically, False otherwise.
         """
+        logger.info(
+            f"check_model_availability is not implemented for {self.__class__.__name__}. Returning False by default."
+        )
         return False
 
     async def register_model(self, model: Model) -> Model:
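For providers that still fall back to the ModelRegistryHelper default, the new logger makes that fallback visible. A rough sketch of what a caller sees (constructing the helper directly with an empty entry list is my own assumption, not something this PR does):

```python
import asyncio

from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper


async def main() -> None:
    helper = ModelRegistryHelper(model_entries=[])
    # Logs "check_model_availability is not implemented for ModelRegistryHelper.
    # Returning False by default." and then returns False.
    print(await helper.check_model_availability("some-model"))


asyncio.run(main())
```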
llama_stack/providers/utils/inference/openai_mixin.py (new file, 272 lines added):

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any

import openai
from openai import NOT_GIVEN, AsyncOpenAI

from llama_stack.apis.inference import (
    Model,
    OpenAIChatCompletion,
    OpenAIChatCompletionChunk,
    OpenAICompletion,
    OpenAIEmbeddingData,
    OpenAIEmbeddingsResponse,
    OpenAIEmbeddingUsage,
    OpenAIMessageParam,
    OpenAIResponseFormatParam,
)
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params

logger = get_logger(name=__name__, category="core")


class OpenAIMixin(ABC):
    """
    Mixin class that provides OpenAI-specific functionality for inference providers.
    This class handles direct OpenAI API calls using the AsyncOpenAI client.

    This is an abstract base class that requires child classes to implement:
    - get_api_key(): Method to retrieve the API key
    - get_base_url(): Method to retrieve the OpenAI-compatible API base URL

    Expected Dependencies:
    - self.model_store: Injected by the Llama Stack distribution system at runtime.
      This provides model registry functionality for looking up registered models.
      The model_store is set in routing_tables/common.py during provider initialization.
    """

    @abstractmethod
    def get_api_key(self) -> str:
        """
        Get the API key.

        This method must be implemented by child classes to provide the API key
        for authenticating with the OpenAI API or compatible endpoints.

        :return: The API key as a string
        """
        pass

    @abstractmethod
    def get_base_url(self) -> str:
        """
        Get the OpenAI-compatible API base URL.

        This method must be implemented by child classes to provide the base URL
        for the OpenAI API or compatible endpoints (e.g., "https://api.openai.com/v1").

        :return: The base URL as a string
        """
        pass

    @property
    def client(self) -> AsyncOpenAI:
        """
        Get an AsyncOpenAI client instance.

        Uses the abstract methods get_api_key() and get_base_url() which must be
        implemented by child classes.
        """
        return AsyncOpenAI(
            api_key=self.get_api_key(),
            base_url=self.get_base_url(),
        )

    async def _get_provider_model_id(self, model: str) -> str:
        """
        Get the provider-specific model ID from the model store.

        This is a utility method that looks up the registered model and returns
        the provider_resource_id that should be used for actual API calls.

        :param model: The registered model name/identifier
        :return: The provider-specific model ID (e.g., "gpt-4")
        """
        # Look up the registered model to get the provider-specific model ID
        # self.model_store is injected by the distribution system at runtime
        model_obj: Model = await self.model_store.get_model(model)  # type: ignore[attr-defined]
        # provider_resource_id is str | None, but we expect it to be str for OpenAI calls
        if model_obj.provider_resource_id is None:
            raise ValueError(f"Model {model} has no provider_resource_id")
        return model_obj.provider_resource_id

    async def openai_completion(
        self,
        model: str,
        prompt: str | list[str] | list[int] | list[list[int]],
        best_of: int | None = None,
        echo: bool | None = None,
        frequency_penalty: float | None = None,
        logit_bias: dict[str, float] | None = None,
        logprobs: bool | None = None,
        max_tokens: int | None = None,
        n: int | None = None,
        presence_penalty: float | None = None,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stream: bool | None = None,
        stream_options: dict[str, Any] | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        user: str | None = None,
        guided_choice: list[str] | None = None,
        prompt_logprobs: int | None = None,
        suffix: str | None = None,
    ) -> OpenAICompletion:
        """
        Direct OpenAI completion API call.
        """
        if guided_choice is not None:
            logger.warning("guided_choice is not supported by the OpenAI API. Ignoring.")
        if prompt_logprobs is not None:
            logger.warning("prompt_logprobs is not supported by the OpenAI API. Ignoring.")

        # TODO: fix openai_completion to return type compatible with OpenAI's API response
        return await self.client.completions.create(  # type: ignore[no-any-return]
            **await prepare_openai_completion_params(
                model=await self._get_provider_model_id(model),
                prompt=prompt,
                best_of=best_of,
                echo=echo,
                frequency_penalty=frequency_penalty,
                logit_bias=logit_bias,
                logprobs=logprobs,
                max_tokens=max_tokens,
                n=n,
                presence_penalty=presence_penalty,
                seed=seed,
                stop=stop,
                stream=stream,
                stream_options=stream_options,
                temperature=temperature,
                top_p=top_p,
                user=user,
                suffix=suffix,
            )
        )

    async def openai_chat_completion(
        self,
        model: str,
        messages: list[OpenAIMessageParam],
        frequency_penalty: float | None = None,
        function_call: str | dict[str, Any] | None = None,
        functions: list[dict[str, Any]] | None = None,
        logit_bias: dict[str, float] | None = None,
        logprobs: bool | None = None,
        max_completion_tokens: int | None = None,
        max_tokens: int | None = None,
        n: int | None = None,
        parallel_tool_calls: bool | None = None,
        presence_penalty: float | None = None,
        response_format: OpenAIResponseFormatParam | None = None,
        seed: int | None = None,
        stop: str | list[str] | None = None,
        stream: bool | None = None,
        stream_options: dict[str, Any] | None = None,
        temperature: float | None = None,
        tool_choice: str | dict[str, Any] | None = None,
        tools: list[dict[str, Any]] | None = None,
        top_logprobs: int | None = None,
        top_p: float | None = None,
        user: str | None = None,
    ) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
        """
        Direct OpenAI chat completion API call.
        """
        # Type ignore because return types are compatible
        return await self.client.chat.completions.create(  # type: ignore[no-any-return]
            **await prepare_openai_completion_params(
                model=await self._get_provider_model_id(model),
                messages=messages,
                frequency_penalty=frequency_penalty,
                function_call=function_call,
                functions=functions,
                logit_bias=logit_bias,
                logprobs=logprobs,
                max_completion_tokens=max_completion_tokens,
                max_tokens=max_tokens,
                n=n,
                parallel_tool_calls=parallel_tool_calls,
                presence_penalty=presence_penalty,
                response_format=response_format,
                seed=seed,
                stop=stop,
                stream=stream,
                stream_options=stream_options,
                temperature=temperature,
                tool_choice=tool_choice,
                tools=tools,
                top_logprobs=top_logprobs,
                top_p=top_p,
                user=user,
            )
        )

    async def openai_embeddings(
        self,
        model: str,
        input: str | list[str],
        encoding_format: str | None = "float",
        dimensions: int | None = None,
        user: str | None = None,
    ) -> OpenAIEmbeddingsResponse:
        """
        Direct OpenAI embeddings API call.
        """
        # Call OpenAI embeddings API with properly typed parameters
        response = await self.client.embeddings.create(
            model=await self._get_provider_model_id(model),
            input=input,
            encoding_format=encoding_format if encoding_format is not None else NOT_GIVEN,
            dimensions=dimensions if dimensions is not None else NOT_GIVEN,
            user=user if user is not None else NOT_GIVEN,
        )

        data = []
        for i, embedding_data in enumerate(response.data):
            data.append(
                OpenAIEmbeddingData(
                    embedding=embedding_data.embedding,
                    index=i,
                )
            )

        usage = OpenAIEmbeddingUsage(
            prompt_tokens=response.usage.prompt_tokens,
            total_tokens=response.usage.total_tokens,
        )

        return OpenAIEmbeddingsResponse(
            data=data,
            model=response.model,
            usage=usage,
        )

    async def check_model_availability(self, model: str) -> bool:
        """
        Check if a specific model is available from OpenAI.

        :param model: The model identifier to check.
        :return: True if the model is available dynamically, False otherwise.
        """
        try:
            # Direct model lookup - returns model or raises NotFoundError
            await self.client.models.retrieve(model)
            return True
        except openai.NotFoundError:
            # Model doesn't exist - this is expected for unavailable models
            pass
        except Exception as e:
            # All other errors (auth, rate limit, network, etc.)
            logger.warning(f"Failed to check model availability for {model}: {e}")

        return False
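A rough usage sketch of the mixin once the two abstract methods are implemented; the `ExampleAdapter`, its credential, endpoint, and model id are hypothetical, and `model_store` (needed by the `openai_*` methods) is normally injected by the distribution at runtime:

```python
import asyncio

from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class ExampleAdapter(OpenAIMixin):
    """Hypothetical provider pointing at an OpenAI-compatible server."""

    def get_api_key(self) -> str:
        return "sk-example"  # placeholder credential

    def get_base_url(self) -> str:
        return "http://localhost:8000/v1"  # hypothetical local endpoint


async def main() -> None:
    adapter = ExampleAdapter()
    # check_model_availability() only needs the client; NotFoundError and other
    # failures are logged and reported as "not available" (False).
    print(await adapter.check_model_availability("meta-llama/Llama-3.1-8B-Instruct"))


asyncio.run(main())
```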
Unit tests (provider OpenAI client caching):

@@ -10,6 +10,8 @@ from unittest.mock import MagicMock
 from llama_stack.distribution.request_headers import request_provider_data_context
 from llama_stack.providers.remote.inference.groq.config import GroqConfig
 from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
+from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
+from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter
 from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
 from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
 from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
@@ -50,7 +52,7 @@ def test_openai_provider_openai_client_caching():
     with request_provider_data_context(
         {"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
     ):
-        openai_client = inference_adapter._get_openai_client()
+        openai_client = inference_adapter.client
         assert openai_client.api_key == api_key
@@ -71,3 +73,18 @@ def test_together_provider_openai_client_caching():
             assert together_client.client.api_key == api_key
             openai_client = inference_adapter._get_openai_client()
             assert openai_client.api_key == api_key
+
+
+def test_llama_compat_provider_openai_client_caching():
+    """Ensure the LlamaCompat provider does not cache api keys across client requests"""
+    config = LlamaCompatConfig()
+    inference_adapter = LlamaCompatInferenceAdapter(config)
+
+    inference_adapter.__provider_spec__ = MagicMock()
+    inference_adapter.__provider_spec__.provider_data_validator = (
+        "llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator"
+    )
+
+    for api_key in ["test1", "test2"]:
+        with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"llama_api_key": api_key})}):
+            assert inference_adapter.client.api_key == api_key