mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-25 19:48:04 +00:00
Merge branch 'main' into allow-dynamic-models-nvidia
This commit is contained in:
commit
6173d7a308
71 changed files with 3107 additions and 2381 deletions
|
|
@ -3,16 +3,17 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import logging
|
||||
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import (
|
||||
LlamaCompatConfig,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
|
||||
LiteLLMOpenAIMixin,
|
||||
)
|
||||
from llama_api_client import AsyncLlamaAPIClient, NotFoundError
|
||||
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from .models import MODEL_ENTRIES
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: LlamaCompatConfig
|
||||
|
|
@ -27,8 +28,32 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
)
|
||||
self.config = config
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
Check if a specific model is available from Llama API.
|
||||
|
||||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
"""
|
||||
try:
|
||||
llama_api_client = self._get_llama_api_client()
|
||||
retrieved_model = await llama_api_client.models.retrieve(model)
|
||||
logger.info(f"Model {retrieved_model.id} is available from Llama API")
|
||||
return True
|
||||
|
||||
except NotFoundError:
|
||||
logger.error(f"Model {model} is not available from Llama API")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check model availability from Llama API: {e}")
|
||||
return False
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
||||
|
||||
def _get_llama_api_client(self) -> AsyncLlamaAPIClient:
|
||||
return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@
|
|||
import logging
|
||||
import warnings
|
||||
from collections.abc import AsyncIterator
|
||||
from functools import lru_cache
|
||||
from typing import Any
|
||||
|
||||
from openai import APIConnectionError, AsyncOpenAI, BadRequestError
|
||||
|
|
@ -98,41 +97,21 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
# If we can't retrieve the model, it's not available
|
||||
return False
|
||||
|
||||
@lru_cache # noqa: B019
|
||||
def _get_client(self, provider_model_id: str | None = None) -> AsyncOpenAI:
|
||||
@property
|
||||
def _client(self) -> AsyncOpenAI:
|
||||
"""
|
||||
For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However,
|
||||
some models are hosted on different URLs. This function returns the appropriate client
|
||||
for the given provider_model_id.
|
||||
Returns an OpenAI client for the configured NVIDIA API endpoint.
|
||||
|
||||
This relies on lru_cache and self._default_client to avoid creating a new client for each request
|
||||
or for each model that is hosted on https://integrate.api.nvidia.com/v1.
|
||||
|
||||
:param provider_model_id: The provider model ID (optional, defaults to primary endpoint)
|
||||
:return: An OpenAI client
|
||||
"""
|
||||
|
||||
@lru_cache # noqa: B019
|
||||
def _get_client_for_base_url(base_url: str) -> AsyncOpenAI:
|
||||
"""
|
||||
Maintain a single OpenAI client per base_url.
|
||||
"""
|
||||
return AsyncOpenAI(
|
||||
base_url=base_url,
|
||||
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
|
||||
timeout=self._config.timeout,
|
||||
)
|
||||
|
||||
special_model_urls = {
|
||||
"meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct",
|
||||
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
|
||||
}
|
||||
|
||||
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
|
||||
|
||||
if provider_model_id and _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
|
||||
base_url = special_model_urls[provider_model_id]
|
||||
return _get_client_for_base_url(base_url)
|
||||
return AsyncOpenAI(
|
||||
base_url=base_url,
|
||||
api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"),
|
||||
timeout=self._config.timeout,
|
||||
)
|
||||
|
||||
async def _get_provider_model_id(self, model_id: str) -> str:
|
||||
if not self.model_store:
|
||||
|
|
@ -174,7 +153,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
)
|
||||
|
||||
try:
|
||||
response = await self._get_client(provider_model_id).completions.create(**request)
|
||||
response = await self._client.completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
|
|
@ -227,7 +206,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
extra_body["input_type"] = task_type_options[task_type]
|
||||
|
||||
try:
|
||||
response = await self._get_client(provider_model_id).embeddings.create(
|
||||
response = await self._client.embeddings.create(
|
||||
model=provider_model_id,
|
||||
input=input,
|
||||
extra_body=extra_body,
|
||||
|
|
@ -288,7 +267,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
)
|
||||
|
||||
try:
|
||||
response = await self._get_client(provider_model_id).chat.completions.create(**request)
|
||||
response = await self._client.chat.completions.create(**request)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
|
|
@ -344,7 +323,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
)
|
||||
|
||||
try:
|
||||
return await self._get_client(provider_model_id).completions.create(**params)
|
||||
return await self._client.completions.create(**params)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
||||
|
|
@ -403,6 +382,6 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
)
|
||||
|
||||
try:
|
||||
return await self._get_client(provider_model_id).chat.completions.create(**params)
|
||||
return await self._client.chat.completions.create(**params)
|
||||
except APIConnectionError as e:
|
||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import logging
|
|||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai import AsyncOpenAI, NotFoundError
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIChatCompletion,
|
||||
|
|
@ -60,6 +60,27 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
|
|||
# litellm specific model names, an abstraction leak.
|
||||
self.is_openai_compat = True
|
||||
|
||||
async def check_model_availability(self, model: str) -> bool:
|
||||
"""
|
||||
Check if a specific model is available from OpenAI.
|
||||
|
||||
:param model: The model identifier to check.
|
||||
:return: True if the model is available dynamically, False otherwise.
|
||||
"""
|
||||
try:
|
||||
openai_client = self._get_openai_client()
|
||||
retrieved_model = await openai_client.models.retrieve(model)
|
||||
logger.info(f"Model {retrieved_model.id} is available from OpenAI")
|
||||
return True
|
||||
|
||||
except NotFoundError:
|
||||
logger.error(f"Model {model} is not available from OpenAI")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check model availability from OpenAI: {e}")
|
||||
return False
|
||||
|
||||
async def initialize(self) -> None:
|
||||
await super().initialize()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue