diff --git a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py
index 29b5e889a..5f9cb20b2 100644
--- a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py
+++ b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py
@@ -3,16 +3,17 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import logging
 
-from llama_stack.providers.remote.inference.llama_openai_compat.config import (
-    LlamaCompatConfig,
-)
-from llama_stack.providers.utils.inference.litellm_openai_mixin import (
-    LiteLLMOpenAIMixin,
-)
+from llama_api_client import AsyncLlamaAPIClient, NotFoundError
+
+from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
+from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 
 from .models import MODEL_ENTRIES
 
+logger = logging.getLogger(__name__)
+
 
 class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
     _config: LlamaCompatConfig
@@ -27,8 +28,32 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
         )
         self.config = config
 
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from Llama API.
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
+        try:
+            llama_api_client = self._get_llama_api_client()
+            retrieved_model = await llama_api_client.models.retrieve(model)
+            logger.info(f"Model {retrieved_model.id} is available from Llama API")
+            return True
+
+        except NotFoundError:
+            logger.error(f"Model {model} is not available from Llama API")
+            return False
+
+        except Exception as e:
+            logger.error(f"Failed to check model availability from Llama API: {e}")
+            return False
+
     async def initialize(self):
         await super().initialize()
 
     async def shutdown(self):
         await super().shutdown()
+
+    def _get_llama_api_client(self) -> AsyncLlamaAPIClient:
+        return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base)
diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py
index 818883919..7e167f621 100644
--- a/llama_stack/providers/remote/inference/openai/openai.py
+++ b/llama_stack/providers/remote/inference/openai/openai.py
@@ -8,7 +8,7 @@ import logging
 from collections.abc import AsyncIterator
 from typing import Any
 
-from openai import AsyncOpenAI
+from openai import AsyncOpenAI, NotFoundError
 
 from llama_stack.apis.inference import (
     OpenAIChatCompletion,
@@ -60,6 +60,27 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         # litellm specific model names, an abstraction leak.
         self.is_openai_compat = True
 
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from OpenAI.
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
+        try:
+            openai_client = self._get_openai_client()
+            retrieved_model = await openai_client.models.retrieve(model)
+            logger.info(f"Model {retrieved_model.id} is available from OpenAI")
+            return True
+
+        except NotFoundError:
+            logger.error(f"Model {model} is not available from OpenAI")
+            return False
+
+        except Exception as e:
+            logger.error(f"Failed to check model availability from OpenAI: {e}")
+            return False
+
     async def initialize(self) -> None:
         await super().initialize()
 
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 188e82125..0de267f6c 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -13,7 +13,6 @@ from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
 )
-from llama_stack.apis.common.errors import UnsupportedModelError
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -39,7 +38,6 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.models import Model
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
@@ -90,12 +88,6 @@ class LiteLLMOpenAIMixin(
     async def shutdown(self):
         pass
 
-    async def register_model(self, model: Model) -> Model:
-        model_id = self.get_provider_model_id(model.provider_resource_id)
-        if model_id is None:
-            raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys())
-        return model
-
     def get_litellm_model_name(self, model_id: str) -> str:
         # users may be using openai/ prefix in their model names. the openai/models.py did this by default.
         # model_id.startswith("openai/") is for backwards compatibility.
diff --git a/pyproject.toml b/pyproject.toml
index b557dfb9d..72f3a323f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
     "jinja2>=3.1.6",
     "jsonschema",
     "llama-stack-client>=0.2.15",
+    "llama-api-client>=0.1.2",
     "openai>=1.66",
     "prompt-toolkit",
     "python-dotenv",
diff --git a/requirements.txt b/requirements.txt
index eb97f7b4c..1106efac5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,6 +13,7 @@ annotated-types==0.7.0
 anyio==4.8.0
     # via
     #   httpx
+    #   llama-api-client
     #   llama-stack-client
     #   openai
     #   starlette
@@ -49,6 +50,7 @@ deprecated==1.2.18
     #   opentelemetry-semantic-conventions
 distro==1.9.0
     # via
+    #   llama-api-client
     #   llama-stack-client
     #   openai
 ecdsa==0.19.1
@@ -80,6 +82,7 @@ httpcore==1.0.9
     # via httpx
 httpx==0.28.1
     # via
+    #   llama-api-client
     #   llama-stack
     #   llama-stack-client
     #   openai
@@ -101,6 +104,8 @@ jsonschema==4.23.0
     # via llama-stack
 jsonschema-specifications==2024.10.1
     # via jsonschema
+llama-api-client==0.1.2
+    # via llama-stack
 llama-stack-client==0.2.15
     # via llama-stack
 markdown-it-py==3.0.0
@@ -165,6 +170,7 @@ pycparser==2.22 ; platform_python_implementation != 'PyPy'
 pydantic==2.10.6
     # via
     #   fastapi
+    #   llama-api-client
     #   llama-stack
     #   llama-stack-client
     #   openai
@@ -215,6 +221,7 @@ six==1.17.0
 sniffio==1.3.1
     # via
     #   anyio
+    #   llama-api-client
     #   llama-stack-client
     #   openai
 starlette==0.45.3
@@ -239,6 +246,7 @@ typing-extensions==4.12.2
     # via
     #   anyio
     #   fastapi
     #   huggingface-hub
+    #   llama-api-client
     #   llama-stack-client
     #   openai
     #   opentelemetry-sdk
diff --git a/uv.lock b/uv.lock
index 666cdf21f..7a9c5cab0 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1268,6 +1268,23 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/2a/f7/67689245f48b9e79bcd2f3a10a3690cb1918fb99fffd5a623ed2496bca66/litellm-1.74.2-py3-none-any.whl", hash = "sha256:29bb555b45128e4cc696e72921a6ec24e97b14e9b69e86eed6f155124ad629b1", size = 8587065 },
 ]
 
+[[package]]
+name = "llama-api-client"
+version = "0.1.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "anyio" },
+    { name = "distro" },
+    { name = "httpx" },
+    { name = "pydantic" },
+    { name = "sniffio" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/d0/78/875de3a16efd0442718ac47cc27319cd80cc5f38e12298e454e08611acc4/llama_api_client-0.1.2.tar.gz", hash = "sha256:709011f2d506009b1b3b3bceea1c84f2a3a7600df1420fb256e680fcd7251387", size = 113695, upload-time = "2025-06-27T19:56:14.057Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/99/08/5d7e6e7e6af5353391376288c200acacebb8e6b156d3636eae598a451673/llama_api_client-0.1.2-py3-none-any.whl", hash = "sha256:8ad6e10726f74b2302bfd766c61c41355a9ecf60f57cde2961882d22af998941", size = 84091, upload-time = "2025-06-27T19:56:12.8Z" },
+]
+
 [[package]]
 name = "llama-stack"
 version = "0.2.15"
@@ -1283,6 +1300,7 @@ dependencies = [
     { name = "huggingface-hub" },
     { name = "jinja2" },
     { name = "jsonschema" },
+    { name = "llama-api-client" },
     { name = "llama-stack-client" },
     { name = "openai" },
     { name = "opentelemetry-exporter-otlp-proto-http" },
@@ -1398,6 +1416,7 @@ requires-dist = [
     { name = "jsonschema" },
     { name = "llama-stack-client", specifier = ">=0.2.15" },
     { name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.15" },
+    { name = "llama-api-client", specifier = ">=0.1.2" },
     { name = "openai", specifier = ">=1.66" },
     { name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" },
     { name = "opentelemetry-sdk", specifier = ">=1.30.0" },
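Reviewer note, not part of the patch: a minimal sketch of how the new check_model_availability hook could be exercised against the OpenAI adapter. It assumes the config class is OpenAIConfig with an api_key field and that no request-scoped provider data is needed; those details are not shown in this diff, so treat the snippet as illustrative rather than as the project's actual test harness.

```python
# Illustrative usage only -- not part of this diff.
# Assumes OpenAIConfig(api_key=...) is the adapter's config type; adjust to the
# real config module if the name or path differs.
import asyncio

from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter


async def main() -> None:
    adapter = OpenAIInferenceAdapter(OpenAIConfig(api_key="sk-..."))
    await adapter.initialize()

    # check_model_availability() never raises: it returns True when the model
    # can be retrieved from the provider, and False on NotFoundError or any
    # other failure, logging the reason either way.
    if await adapter.check_model_availability("gpt-4o-mini"):
        print("model is available and can be registered dynamically")
    else:
        print("model is not available from OpenAI")

    await adapter.shutdown()


asyncio.run(main())
```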