refactor: switched from query_available_models() to check_model_availability() in OpenAIInferenceAdapter and LlamaCompatInferenceAdapter

r3v5 2025-07-15 18:58:28 +01:00
parent fa5935bd80
commit c473de6b4f
No known key found for this signature in database
GPG key ID: 7758B9F272DE67D9
3 changed files with 48 additions and 17 deletions
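
At the interface level, the refactor swaps a list-everything hook for a per-model boolean probe. The before/after signatures, as they appear in the diffs below:

    # Before: each adapter enumerated every model its provider serves.
    async def query_available_models(self) -> list[str]: ...

    # After: each adapter answers availability for one model on demand,
    # using a single models.retrieve() call instead of a full listing.
    async def check_model_availability(self, model: str) -> bool: ...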


@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import logging
 
-from llama_api_client import AsyncLlamaAPIClient
+from llama_api_client import AsyncLlamaAPIClient, NotFoundError
 
 from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
@@ -27,20 +27,33 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
             openai_compat_api_base=config.openai_compat_api_base,
         )
         self.config = config
-        self._llama_api_client = AsyncLlamaAPIClient(api_key=config.api_key)
 
-    async def query_available_models(self) -> list[str]:
-        """Query available models from the Llama API."""
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from Llama API.
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
         try:
-            available_models = await self._llama_api_client.models.list()
-            logger.info(f"Available models from Llama API: {available_models}")
-            return [model.id for model in available_models]
+            llama_api_client = self._get_llama_api_client()
+            retrieved_model = await llama_api_client.models.retrieve(model)
+            logger.info(f"Model {retrieved_model.id} is available from Llama API")
+            return True
+
+        except NotFoundError:
+            logger.error(f"Model {model} is not available from Llama API")
+            return False
+
         except Exception as e:
-            logger.warning(f"Failed to query available models from Llama API: {e}")
-            return []
+            logger.error(f"Failed to check model availability from Llama API: {e}")
+            return False
 
     async def initialize(self):
         await super().initialize()
 
     async def shutdown(self):
         await super().shutdown()
+
+    def _get_llama_api_client(self) -> AsyncLlamaAPIClient:
+        return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base)
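
For context, a minimal sketch of how a caller might exercise the new probe. This is hypothetical code, not part of the commit: the import paths and the LlamaCompatConfig(api_key=...) construction are assumptions, and the model id is only an example.

    import asyncio

    # Assumed import paths; the adapter class is defined in the file diffed above.
    from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
    from llama_stack.providers.remote.inference.llama_openai_compat.llama_openai_compat import (
        LlamaCompatInferenceAdapter,
    )

    async def main() -> None:
        adapter = LlamaCompatInferenceAdapter(LlamaCompatConfig(api_key="<llama-api-key>"))
        await adapter.initialize()
        # True only if the Llama API can retrieve this exact model id.
        if await adapter.check_model_availability("Llama-3.3-70B-Instruct"):
            print("model is available")
        await adapter.shutdown()

    asyncio.run(main())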


@@ -8,7 +8,7 @@ import logging
 from collections.abc import AsyncIterator
 from typing import Any
 
-from openai import AsyncOpenAI
+from openai import AsyncOpenAI, NotFoundError
 
 from llama_stack.apis.inference import (
     OpenAIChatCompletion,
@@ -60,16 +60,26 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         # litellm specific model names, an abstraction leak.
         self.is_openai_compat = True
 
-    async def query_available_models(self) -> list[str]:
-        """Query available models from the OpenAI API"""
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from OpenAI.
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
         try:
             openai_client = self._get_openai_client()
-            available_models = await openai_client.models.list()
-            logger.info(f"Available models from OpenAI: {available_models.data}")
-            return [model.id for model in available_models.data]
+            retrieved_model = await openai_client.models.retrieve(model)
+            logger.info(f"Model {retrieved_model.id} is available from OpenAI")
+            return True
+
+        except NotFoundError:
+            logger.error(f"Model {model} is not available from OpenAI")
+            return False
+
         except Exception as e:
-            logger.warning(f"Failed to query available models from OpenAI: {e}")
-            return []
+            logger.error(f"Failed to check model availability from OpenAI: {e}")
+            return False
 
     async def initialize(self) -> None:
         await super().initialize()
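
The retrieve-and-catch pattern adopted here also works standalone against the OpenAI SDK. A minimal sketch, assuming OPENAI_API_KEY is set in the environment; the model name is an example:

    import asyncio

    from openai import AsyncOpenAI, NotFoundError

    async def is_available(model: str) -> bool:
        client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
        try:
            # models.retrieve() returns the model object if it exists and
            # raises NotFoundError otherwise: one call, no full listing.
            await client.models.retrieve(model)
            return True
        except NotFoundError:
            return False

    print(asyncio.run(is_available("gpt-4o")))

Compared with the old list-based check, this avoids paging through every model the account can see just to test a single id.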


@@ -13,6 +13,7 @@ annotated-types==0.7.0
 anyio==4.8.0
     # via
     #   httpx
+    #   llama-api-client
     #   llama-stack-client
     #   openai
     #   starlette
@@ -49,6 +50,7 @@ deprecated==1.2.18
     #   opentelemetry-semantic-conventions
 distro==1.9.0
     # via
+    #   llama-api-client
     #   llama-stack-client
     #   openai
 ecdsa==0.19.1
@@ -80,6 +82,7 @@ httpcore==1.0.9
     # via httpx
 httpx==0.28.1
     # via
+    #   llama-api-client
     #   llama-stack
     #   llama-stack-client
     #   openai
@@ -101,6 +104,8 @@ jsonschema==4.23.0
     # via llama-stack
 jsonschema-specifications==2024.10.1
     # via jsonschema
+llama-api-client==0.1.2
+    # via llama-stack
 llama-stack-client==0.2.15
     # via llama-stack
 markdown-it-py==3.0.0
@@ -165,6 +170,7 @@ pycparser==2.22 ; platform_python_implementation != 'PyPy'
 pydantic==2.10.6
     # via
     #   fastapi
+    #   llama-api-client
     #   llama-stack
     #   llama-stack-client
     #   openai
@@ -215,6 +221,7 @@ six==1.17.0
 sniffio==1.3.1
     # via
     #   anyio
+    #   llama-api-client
     #   llama-stack-client
     #   openai
 starlette==0.45.3
@@ -239,6 +246,7 @@ typing-extensions==4.12.2
     #   anyio
     #   fastapi
     #   huggingface-hub
+    #   llama-api-client
     #   llama-stack-client
     #   openai
     #   opentelemetry-sdk
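
The lock-file hunks above pin llama-api-client==0.1.2 and thread it into the existing transitive dependencies. A quick, illustrative sanity check that the newly locked client resolves in an environment built from this file:

    # Illustrative check, not part of the commit.
    from importlib.metadata import version

    import llama_api_client  # noqa: F401  (import should succeed once the lock is installed)

    print(version("llama-api-client"))  # expected: 0.1.2 per this lock file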