diff --git a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py
index 2c45ddddd..5f9cb20b2 100644
--- a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py
+++ b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 import logging
 
-from llama_api_client import AsyncLlamaAPIClient
+from llama_api_client import AsyncLlamaAPIClient, NotFoundError
 
 from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
 from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
@@ -27,20 +27,33 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
             openai_compat_api_base=config.openai_compat_api_base,
         )
         self.config = config
-        self._llama_api_client = AsyncLlamaAPIClient(api_key=config.api_key)
 
-    async def query_available_models(self) -> list[str]:
-        """Query available models from the Llama API."""
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from Llama API.
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
         try:
-            available_models = await self._llama_api_client.models.list()
-            logger.info(f"Available models from Llama API: {available_models}")
-            return [model.id for model in available_models]
+            llama_api_client = self._get_llama_api_client()
+            retrieved_model = await llama_api_client.models.retrieve(model)
+            logger.info(f"Model {retrieved_model.id} is available from Llama API")
+            return True
+
+        except NotFoundError:
+            logger.error(f"Model {model} is not available from Llama API")
+            return False
+
         except Exception as e:
-            logger.warning(f"Failed to query available models from Llama API: {e}")
-            return []
+            logger.error(f"Failed to check model availability from Llama API: {e}")
+            return False
 
     async def initialize(self):
         await super().initialize()
 
     async def shutdown(self):
         await super().shutdown()
+
+    def _get_llama_api_client(self) -> AsyncLlamaAPIClient:
+        return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base)
diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py
index 535cf793a..7e167f621 100644
--- a/llama_stack/providers/remote/inference/openai/openai.py
+++ b/llama_stack/providers/remote/inference/openai/openai.py
@@ -8,7 +8,7 @@ import logging
 from collections.abc import AsyncIterator
 from typing import Any
 
-from openai import AsyncOpenAI
+from openai import AsyncOpenAI, NotFoundError
 
 from llama_stack.apis.inference import (
     OpenAIChatCompletion,
@@ -60,16 +60,26 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         # litellm specific model names, an abstraction leak.
         self.is_openai_compat = True
 
-    async def query_available_models(self) -> list[str]:
-        """Query available models from the OpenAI API"""
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from OpenAI.
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
         try:
             openai_client = self._get_openai_client()
-            available_models = await openai_client.models.list()
-            logger.info(f"Available models from OpenAI: {available_models.data}")
-            return [model.id for model in available_models.data]
+            retrieved_model = await openai_client.models.retrieve(model)
+            logger.info(f"Model {retrieved_model.id} is available from OpenAI")
+            return True
+
+        except NotFoundError:
+            logger.error(f"Model {model} is not available from OpenAI")
+            return False
+
         except Exception as e:
-            logger.warning(f"Failed to query available models from OpenAI: {e}")
-            return []
+            logger.error(f"Failed to check model availability from OpenAI: {e}")
+            return False
 
     async def initialize(self) -> None:
         await super().initialize()
diff --git a/requirements.txt b/requirements.txt
index eb97f7b4c..1106efac5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,6 +13,7 @@ annotated-types==0.7.0
 anyio==4.8.0
     # via
     #   httpx
+    #   llama-api-client
     #   llama-stack-client
     #   openai
     #   starlette
@@ -49,6 +50,7 @@ deprecated==1.2.18
     #   opentelemetry-semantic-conventions
 distro==1.9.0
     # via
+    #   llama-api-client
     #   llama-stack-client
     #   openai
 ecdsa==0.19.1
@@ -80,6 +82,7 @@ httpcore==1.0.9
     # via httpx
 httpx==0.28.1
     # via
+    #   llama-api-client
     #   llama-stack
     #   llama-stack-client
     #   openai
@@ -101,6 +104,8 @@ jsonschema==4.23.0
     # via llama-stack
 jsonschema-specifications==2024.10.1
     # via jsonschema
+llama-api-client==0.1.2
+    # via llama-stack
 llama-stack-client==0.2.15
     # via llama-stack
 markdown-it-py==3.0.0
@@ -165,6 +170,7 @@ pycparser==2.22 ; platform_python_implementation != 'PyPy'
 pydantic==2.10.6
     # via
     #   fastapi
+    #   llama-api-client
     #   llama-stack
     #   llama-stack-client
     #   openai
@@ -215,6 +221,7 @@ six==1.17.0
 sniffio==1.3.1
     # via
     #   anyio
+    #   llama-api-client
     #   llama-stack-client
     #   openai
 starlette==0.45.3
@@ -239,6 +246,7 @@ typing-extensions==4.12.2
     #   anyio
     #   fastapi
     #   huggingface-hub
+    #   llama-api-client
     #   llama-stack-client
     #   openai
     #   opentelemetry-sdk