Merge branch 'main' into allow-dynamic-models-nvidia

This commit is contained in:
Matthew Farrellee 2025-07-16 12:53:44 -04:00
commit 6173d7a308
71 changed files with 3107 additions and 2381 deletions

View file

@ -3,16 +3,17 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from llama_stack.providers.remote.inference.llama_openai_compat.config import (
LlamaCompatConfig,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from llama_api_client import AsyncLlamaAPIClient, NotFoundError
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from .models import MODEL_ENTRIES
logger = logging.getLogger(__name__)
class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
_config: LlamaCompatConfig
@ -27,8 +28,32 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
)
self.config = config
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from Llama API.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
try:
llama_api_client = self._get_llama_api_client()
retrieved_model = await llama_api_client.models.retrieve(model)
logger.info(f"Model {retrieved_model.id} is available from Llama API")
return True
except NotFoundError:
logger.error(f"Model {model} is not available from Llama API")
return False
except Exception as e:
logger.error(f"Failed to check model availability from Llama API: {e}")
return False
async def initialize(self):
await super().initialize()
async def shutdown(self):
await super().shutdown()
def _get_llama_api_client(self) -> AsyncLlamaAPIClient:
return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base)

View file

@ -7,7 +7,6 @@
import logging
import warnings
from collections.abc import AsyncIterator
from functools import lru_cache
from typing import Any
from openai import APIConnectionError, AsyncOpenAI, BadRequestError
@ -98,41 +97,21 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
# If we can't retrieve the model, it's not available
return False
@lru_cache # noqa: B019
def _get_client(self, provider_model_id: str | None = None) -> AsyncOpenAI:
@property
def _client(self) -> AsyncOpenAI:
"""
For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However,
some models are hosted on different URLs. This function returns the appropriate client
for the given provider_model_id.
Returns an OpenAI client for the configured NVIDIA API endpoint.
This relies on lru_cache and self._default_client to avoid creating a new client for each request
or for each model that is hosted on https://integrate.api.nvidia.com/v1.
:param provider_model_id: The provider model ID (optional, defaults to primary endpoint)
:return: An OpenAI client
"""
@lru_cache # noqa: B019
def _get_client_for_base_url(base_url: str) -> AsyncOpenAI:
"""
Maintain a single OpenAI client per base_url.
"""
return AsyncOpenAI(
base_url=base_url,
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
timeout=self._config.timeout,
)
special_model_urls = {
"meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct",
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
}
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
if provider_model_id and _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
base_url = special_model_urls[provider_model_id]
return _get_client_for_base_url(base_url)
return AsyncOpenAI(
base_url=base_url,
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
timeout=self._config.timeout,
)
async def _get_provider_model_id(self, model_id: str) -> str:
if not self.model_store:
@ -174,7 +153,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
try:
response = await self._get_client(provider_model_id).completions.create(**request)
response = await self._client.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -227,7 +206,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
extra_body["input_type"] = task_type_options[task_type]
try:
response = await self._get_client(provider_model_id).embeddings.create(
response = await self._client.embeddings.create(
model=provider_model_id,
input=input,
extra_body=extra_body,
@ -288,7 +267,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
try:
response = await self._get_client(provider_model_id).chat.completions.create(**request)
response = await self._client.chat.completions.create(**request)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -344,7 +323,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
try:
return await self._get_client(provider_model_id).completions.create(**params)
return await self._client.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
@ -403,6 +382,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
)
try:
return await self._get_client(provider_model_id).chat.completions.create(**params)
return await self._client.chat.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e

View file

@ -8,7 +8,7 @@ import logging
from collections.abc import AsyncIterator
from typing import Any
from openai import AsyncOpenAI
from openai import AsyncOpenAI, NotFoundError
from llama_stack.apis.inference import (
OpenAIChatCompletion,
@ -60,6 +60,27 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
# litellm specific model names, an abstraction leak.
self.is_openai_compat = True
async def check_model_availability(self, model: str) -> bool:
"""
Check if a specific model is available from OpenAI.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
try:
openai_client = self._get_openai_client()
retrieved_model = await openai_client.models.retrieve(model)
logger.info(f"Model {retrieved_model.id} is available from OpenAI")
return True
except NotFoundError:
logger.error(f"Model {model} is not available from OpenAI")
return False
except Exception as e:
logger.error(f"Failed to check model availability from OpenAI: {e}")
return False
async def initialize(self) -> None:
await super().initialize()