refactor: switched from query_available_models() to check_model_availability() in OpenAIInferenceAdapter and LlamaCompatInferenceAdapter

This commit is contained in:
r3v5 2025-07-15 18:58:28 +01:00
parent fa5935bd80
commit c473de6b4f
No known key found for this signature in database
GPG key ID: 7758B9F272DE67D9
3 changed files with 48 additions and 17 deletions

View file

@@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
import logging import logging
from llama_api_client import AsyncLlamaAPIClient from llama_api_client import AsyncLlamaAPIClient, NotFoundError
from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
@@ -27,20 +27,33 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
openai_compat_api_base=config.openai_compat_api_base, openai_compat_api_base=config.openai_compat_api_base,
) )
self.config = config self.config = config
self._llama_api_client = AsyncLlamaAPIClient(api_key=config.api_key)
async def query_available_models(self) -> list[str]: async def check_model_availability(self, model: str) -> bool:
"""Query available models from the Llama API.""" """
Check if a specific model is available from Llama API.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
try: try:
available_models = await self._llama_api_client.models.list() llama_api_client = self._get_llama_api_client()
logger.info(f"Available models from Llama API: {available_models}") retrieved_model = await llama_api_client.models.retrieve(model)
return [model.id for model in available_models] logger.info(f"Model {retrieved_model.id} is available from Llama API")
return True
except NotFoundError:
logger.error(f"Model {model} is not available from Llama API")
return False
except Exception as e: except Exception as e:
logger.warning(f"Failed to query available models from Llama API: {e}") logger.error(f"Failed to check model availability from Llama API: {e}")
return [] return False
async def initialize(self): async def initialize(self):
await super().initialize() await super().initialize()
async def shutdown(self): async def shutdown(self):
await super().shutdown() await super().shutdown()
def _get_llama_api_client(self) -> AsyncLlamaAPIClient:
return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base)

View file

@@ -8,7 +8,7 @@ import logging
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import Any from typing import Any
from openai import AsyncOpenAI from openai import AsyncOpenAI, NotFoundError
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
OpenAIChatCompletion, OpenAIChatCompletion,
@@ -60,16 +60,26 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
# litellm specific model names, an abstraction leak. # litellm specific model names, an abstraction leak.
self.is_openai_compat = True self.is_openai_compat = True
async def query_available_models(self) -> list[str]: async def check_model_availability(self, model: str) -> bool:
"""Query available models from the OpenAI API""" """
Check if a specific model is available from OpenAI.
:param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
"""
try: try:
openai_client = self._get_openai_client() openai_client = self._get_openai_client()
available_models = await openai_client.models.list() retrieved_model = await openai_client.models.retrieve(model)
logger.info(f"Available models from OpenAI: {available_models.data}") logger.info(f"Model {retrieved_model.id} is available from OpenAI")
return [model.id for model in available_models.data] return True
except NotFoundError:
logger.error(f"Model {model} is not available from OpenAI")
return False
except Exception as e: except Exception as e:
logger.warning(f"Failed to query available models from OpenAI: {e}") logger.error(f"Failed to check model availability from OpenAI: {e}")
return [] return False
async def initialize(self) -> None: async def initialize(self) -> None:
await super().initialize() await super().initialize()

View file

@@ -13,6 +13,7 @@ annotated-types==0.7.0
anyio==4.8.0 anyio==4.8.0
# via # via
# httpx # httpx
# llama-api-client
# llama-stack-client # llama-stack-client
# openai # openai
# starlette # starlette
@@ -49,6 +50,7 @@ deprecated==1.2.18
# opentelemetry-semantic-conventions # opentelemetry-semantic-conventions
distro==1.9.0 distro==1.9.0
# via # via
# llama-api-client
# llama-stack-client # llama-stack-client
# openai # openai
ecdsa==0.19.1 ecdsa==0.19.1
@@ -80,6 +82,7 @@ httpcore==1.0.9
# via httpx # via httpx
httpx==0.28.1 httpx==0.28.1
# via # via
# llama-api-client
# llama-stack # llama-stack
# llama-stack-client # llama-stack-client
# openai # openai
@@ -101,6 +104,8 @@ jsonschema==4.23.0
# via llama-stack # via llama-stack
jsonschema-specifications==2024.10.1 jsonschema-specifications==2024.10.1
# via jsonschema # via jsonschema
llama-api-client==0.1.2
# via llama-stack
llama-stack-client==0.2.15 llama-stack-client==0.2.15
# via llama-stack # via llama-stack
markdown-it-py==3.0.0 markdown-it-py==3.0.0
@@ -165,6 +170,7 @@ pycparser==2.22 ; platform_python_implementation != 'PyPy'
pydantic==2.10.6 pydantic==2.10.6
# via # via
# fastapi # fastapi
# llama-api-client
# llama-stack # llama-stack
# llama-stack-client # llama-stack-client
# openai # openai
@@ -215,6 +221,7 @@ six==1.17.0
sniffio==1.3.1 sniffio==1.3.1
# via # via
# anyio # anyio
# llama-api-client
# llama-stack-client # llama-stack-client
# openai # openai
starlette==0.45.3 starlette==0.45.3
@@ -239,6 +246,7 @@ typing-extensions==4.12.2
# anyio # anyio
# fastapi # fastapi
# huggingface-hub # huggingface-hub
# llama-api-client
# llama-stack-client # llama-stack-client
# openai # openai
# opentelemetry-sdk # opentelemetry-sdk