From fa5935bd8085770f18db75794e411d7ace6377bb Mon Sep 17 00:00:00 2001
From: r3v5
Date: Mon, 14 Jul 2025 12:39:15 +0100
Subject: [PATCH] feat: created dynamic model registration for openai and
 llama openai compat remote inference providers

fix: removed implementation of register_model() from LiteLLMOpenAIMixin,
added log message to llama in query_available_models(), added
llama-api-client dependency to pyproject.toml
---
 .../inference/llama_openai_compat/llama.py  | 24 ++++++++++++++-----
 .../remote/inference/openai/openai.py       | 11 +++++++++
 .../utils/inference/litellm_openai_mixin.py |  8 -------
 pyproject.toml                              |  1 +
 uv.lock                                     | 19 +++++++++++++++
 5 files changed, 49 insertions(+), 14 deletions(-)

diff --git a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py
index 29b5e889a..2c45ddddd 100644
--- a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py
+++ b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py
@@ -3,16 +3,17 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+import logging
 
-from llama_stack.providers.remote.inference.llama_openai_compat.config import (
-    LlamaCompatConfig,
-)
-from llama_stack.providers.utils.inference.litellm_openai_mixin import (
-    LiteLLMOpenAIMixin,
-)
+from llama_api_client import AsyncLlamaAPIClient
+
+from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
+from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
 
 from .models import MODEL_ENTRIES
 
+logger = logging.getLogger(__name__)
+
 
 class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
     _config: LlamaCompatConfig
@@ -26,6 +27,17 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin):
             openai_compat_api_base=config.openai_compat_api_base,
         )
         self.config = config
+        self._llama_api_client = AsyncLlamaAPIClient(api_key=config.api_key)
+
+    async def query_available_models(self) -> list[str]:
+        """Query available models from the Llama API."""
+        try:
+            available_models = await self._llama_api_client.models.list()
+            logger.info(f"Available models from Llama API: {available_models}")
+            return [model.id for model in available_models]
+        except Exception as e:
+            logger.warning(f"Failed to query available models from Llama API: {e}")
+            return []
 
     async def initialize(self):
         await super().initialize()
diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py
index 818883919..535cf793a 100644
--- a/llama_stack/providers/remote/inference/openai/openai.py
+++ b/llama_stack/providers/remote/inference/openai/openai.py
@@ -60,6 +60,17 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
         # litellm specific model names, an abstraction leak.
         self.is_openai_compat = True
 
+    async def query_available_models(self) -> list[str]:
+        """Query available models from the OpenAI API"""
+        try:
+            openai_client = self._get_openai_client()
+            available_models = await openai_client.models.list()
+            logger.info(f"Available models from OpenAI: {available_models.data}")
+            return [model.id for model in available_models.data]
+        except Exception as e:
+            logger.warning(f"Failed to query available models from OpenAI: {e}")
+            return []
+
     async def initialize(self) -> None:
         await super().initialize()
 
diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
index 188e82125..0de267f6c 100644
--- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py
+++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py
@@ -13,7 +13,6 @@ from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
 )
-from llama_stack.apis.common.errors import UnsupportedModelError
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -39,7 +38,6 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.apis.models import Model
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
@@ -90,12 +88,6 @@ class LiteLLMOpenAIMixin(
     async def shutdown(self):
         pass
 
-    async def register_model(self, model: Model) -> Model:
-        model_id = self.get_provider_model_id(model.provider_resource_id)
-        if model_id is None:
-            raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys())
-        return model
-
     def get_litellm_model_name(self, model_id: str) -> str:
         # users may be using openai/ prefix in their model names. the openai/models.py did this by default.
         # model_id.startswith("openai/") is for backwards compatibility.
diff --git a/pyproject.toml b/pyproject.toml
index b557dfb9d..72f3a323f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
     "jinja2>=3.1.6",
     "jsonschema",
     "llama-stack-client>=0.2.15",
+    "llama-api-client>=0.1.2",
     "openai>=1.66",
     "prompt-toolkit",
     "python-dotenv",
diff --git a/uv.lock b/uv.lock
index 666cdf21f..7a9c5cab0 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1268,6 +1268,23 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/2a/f7/67689245f48b9e79bcd2f3a10a3690cb1918fb99fffd5a623ed2496bca66/litellm-1.74.2-py3-none-any.whl", hash = "sha256:29bb555b45128e4cc696e72921a6ec24e97b14e9b69e86eed6f155124ad629b1", size = 8587065 },
 ]
 
+[[package]]
+name = "llama-api-client"
+version = "0.1.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "anyio" },
+    { name = "distro" },
+    { name = "httpx" },
+    { name = "pydantic" },
+    { name = "sniffio" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/d0/78/875de3a16efd0442718ac47cc27319cd80cc5f38e12298e454e08611acc4/llama_api_client-0.1.2.tar.gz", hash = "sha256:709011f2d506009b1b3b3bceea1c84f2a3a7600df1420fb256e680fcd7251387", size = 113695, upload-time = "2025-06-27T19:56:14.057Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/99/08/5d7e6e7e6af5353391376288c200acacebb8e6b156d3636eae598a451673/llama_api_client-0.1.2-py3-none-any.whl", hash = "sha256:8ad6e10726f74b2302bfd766c61c41355a9ecf60f57cde2961882d22af998941", size = 84091, upload-time = "2025-06-27T19:56:12.8Z" },
+]
+
 [[package]]
 name = "llama-stack"
 version = "0.2.15"
@@ -1283,6 +1300,7 @@ dependencies = [
     { name = "huggingface-hub" },
     { name = "jinja2" },
     { name = "jsonschema" },
+    { name = "llama-api-client" },
     { name = "llama-stack-client" },
     { name = "openai" },
     { name = "opentelemetry-exporter-otlp-proto-http" },
@@ -1398,6 +1416,7 @@ requires-dist = [
     { name = "jsonschema" },
     { name = "llama-stack-client", specifier = ">=0.2.15" },
     { name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.15" },
+    { name = "llama-api-client", specifier = ">=0.1.2" },
     { name = "openai", specifier = ">=1.66" },
     { name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" },
     { name = "opentelemetry-sdk", specifier = ">=1.30.0" },
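
A minimal usage sketch of the new query_available_models() coroutine added above,
assuming LlamaCompatConfig can be constructed with just an api_key and that the
adapter is instantiated directly (both are illustration-only assumptions; real
deployments construct providers through the distribution registry and config):

    import asyncio

    from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig
    from llama_stack.providers.remote.inference.llama_openai_compat.llama import LlamaCompatInferenceAdapter


    async def main() -> None:
        # Placeholder credentials; assumes api_key is the only field that must be set here.
        adapter = LlamaCompatInferenceAdapter(LlamaCompatConfig(api_key="..."))
        await adapter.initialize()

        # query_available_models() returns [] when the remote Llama API is unreachable,
        # so this loop simply does nothing on failure.
        for model_id in await adapter.query_available_models():
            print(f"Llama API advertises model: {model_id}")


    asyncio.run(main())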