Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-22 22:42:25 +00:00)
refactor: switched from query_available_models() to check_model_availability() in OpenAIInferenceAdapter and LlamaCompatInferenceAdapter
Commit: c473de6b4f
Parent: fa5935bd80
3 changed files with 48 additions and 17 deletions
Changes to the LlamaCompatInferenceAdapter (llama_openai_compat provider):

@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import logging
 
-from llama_api_client import AsyncLlamaAPIClient
+from llama_api_client import AsyncLlamaAPIClient, NotFoundError
 
 from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
@@ -27,20 +27,33 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
             openai_compat_api_base=config.openai_compat_api_base,
         )
         self.config = config
-        self._llama_api_client = AsyncLlamaAPIClient(api_key=config.api_key)
 
-    async def query_available_models(self) -> list[str]:
-        """Query available models from the Llama API."""
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from Llama API.
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
         try:
-            available_models = await self._llama_api_client.models.list()
-            logger.info(f"Available models from Llama API: {available_models}")
-            return [model.id for model in available_models]
+            llama_api_client = self._get_llama_api_client()
+            retrieved_model = await llama_api_client.models.retrieve(model)
+            logger.info(f"Model {retrieved_model.id} is available from Llama API")
+            return True
+
+        except NotFoundError:
+            logger.error(f"Model {model} is not available from Llama API")
+            return False
+
         except Exception as e:
-            logger.warning(f"Failed to query available models from Llama API: {e}")
-            return []
+            logger.error(f"Failed to check model availability from Llama API: {e}")
+            return False
 
     async def initialize(self):
         await super().initialize()
 
     async def shutdown(self):
         await super().shutdown()
+
+    def _get_llama_api_client(self) -> AsyncLlamaAPIClient:
+        return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base)
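For reference, a minimal standalone sketch of the availability check this hunk introduces, written directly against the llama_api_client SDK rather than through the adapter. The API key, base URL, and model name below are placeholders, and the SDK surface is assumed to behave the way the diff uses it (models.retrieve raising NotFoundError for unknown models).

import asyncio

from llama_api_client import AsyncLlamaAPIClient, NotFoundError


async def check_model_availability(client: AsyncLlamaAPIClient, model: str) -> bool:
    # Mirrors the adapter logic above: retrieve the single model and map a
    # NotFoundError to False, instead of listing every available model.
    try:
        retrieved = await client.models.retrieve(model)
        print(f"Model {retrieved.id} is available")
        return True
    except NotFoundError:
        print(f"Model {model} is not available")
        return False


async def main() -> None:
    # Placeholder credentials and base URL; the adapter reads these from its
    # config (get_api_key() / config.openai_compat_api_base).
    client = AsyncLlamaAPIClient(api_key="YOUR_API_KEY", base_url="https://api.llama.com/compat/v1/")
    print(await check_model_availability(client, "Llama-4-Scout-17B-16E-Instruct"))


asyncio.run(main())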
Changes to the OpenAIInferenceAdapter (openai provider):

@@ -8,7 +8,7 @@ import logging
 from collections.abc import AsyncIterator
 from typing import Any
 
-from openai import AsyncOpenAI
+from openai import AsyncOpenAI, NotFoundError
 
 from llama_stack.apis.inference import (
     OpenAIChatCompletion,
@@ -60,16 +60,26 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         # litellm specific model names, an abstraction leak.
         self.is_openai_compat = True
 
-    async def query_available_models(self) -> list[str]:
-        """Query available models from the OpenAI API"""
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from OpenAI.
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
         try:
             openai_client = self._get_openai_client()
-            available_models = await openai_client.models.list()
-            logger.info(f"Available models from OpenAI: {available_models.data}")
-            return [model.id for model in available_models.data]
+            retrieved_model = await openai_client.models.retrieve(model)
+            logger.info(f"Model {retrieved_model.id} is available from OpenAI")
+            return True
+
+        except NotFoundError:
+            logger.error(f"Model {model} is not available from OpenAI")
+            return False
+
        except Exception as e:
-            logger.warning(f"Failed to query available models from OpenAI: {e}")
-            return []
+            logger.error(f"Failed to check model availability from OpenAI: {e}")
+            return False
 
     async def initialize(self) -> None:
         await super().initialize()
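Because the new method answers one model at a time instead of returning the whole catalog, callers that used to scan the returned list can fan out per-model checks. A small illustrative sketch using the same openai SDK calls shown in this hunk; the model names are placeholders and the client reads OPENAI_API_KEY from the environment.

import asyncio

from openai import AsyncOpenAI, NotFoundError


async def is_available(client: AsyncOpenAI, model: str) -> bool:
    # Same retrieve-or-NotFoundError pattern as check_model_availability above.
    try:
        await client.models.retrieve(model)
        return True
    except NotFoundError:
        return False


async def main() -> None:
    client = AsyncOpenAI()  # picks up OPENAI_API_KEY from the environment
    candidates = ["gpt-4o-mini", "gpt-4o", "not-a-real-model"]
    results = await asyncio.gather(*(is_available(client, m) for m in candidates))
    for model, ok in zip(candidates, results):
        print(f"{model}: {'available' if ok else 'missing'}")


asyncio.run(main())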
Changes to the dependency lock file (adds llama-api-client):

@@ -13,6 +13,7 @@ annotated-types==0.7.0
 anyio==4.8.0
     # via
     #   httpx
+    #   llama-api-client
     #   llama-stack-client
     #   openai
     #   starlette
@@ -49,6 +50,7 @@ deprecated==1.2.18
     #   opentelemetry-semantic-conventions
 distro==1.9.0
     # via
+    #   llama-api-client
     #   llama-stack-client
     #   openai
 ecdsa==0.19.1
@@ -80,6 +82,7 @@ httpcore==1.0.9
     # via httpx
 httpx==0.28.1
     # via
+    #   llama-api-client
     #   llama-stack
     #   llama-stack-client
     #   openai
@@ -101,6 +104,8 @@ jsonschema==4.23.0
     # via llama-stack
 jsonschema-specifications==2024.10.1
     # via jsonschema
+llama-api-client==0.1.2
+    # via llama-stack
 llama-stack-client==0.2.15
     # via llama-stack
 markdown-it-py==3.0.0
@@ -165,6 +170,7 @@ pycparser==2.22 ; platform_python_implementation != 'PyPy'
 pydantic==2.10.6
     # via
     #   fastapi
+    #   llama-api-client
     #   llama-stack
     #   llama-stack-client
     #   openai
@@ -215,6 +221,7 @@ six==1.17.0
 sniffio==1.3.1
     # via
     #   anyio
+    #   llama-api-client
     #   llama-stack-client
     #   openai
 starlette==0.45.3
@@ -239,6 +246,7 @@ typing-extensions==4.12.2
     #   anyio
     #   fastapi
     #   huggingface-hub
+    #   llama-api-client
     #   llama-stack-client
     #   openai
     #   opentelemetry-sdk