From 521865c388507ef244f76e0fdc249059ef65ea1e Mon Sep 17 00:00:00 2001
From: Matthew Farrellee
Date: Thu, 18 Sep 2025 05:17:11 -0400
Subject: [PATCH] feat: include all models from provider's /v1/models (#3471)

# What does this PR do?

This replaces the static model listing for any provider using OpenAIMixin. Providers currently covered:

- anthropic
- azure openai
- gemini
- groq
- llama-api
- nvidia
- openai
- sambanova
- tgi
- vertexai
- vllm

Not changed: together, which has its own implementation.

## Test Plan

- new unit tests
- manual checks for llama-api, openai, groq, gemini

```
for provider in llama-openai-compat openai groq gemini; do
  uv run llama stack build --image-type venv --providers inference=remote::$provider --run &
  uv run --with llama-stack-client llama-stack-client models list | grep Total
done
```

Results (17 Sep 2025):

- llama-api: 4
- openai: 86
- groq: 21
- gemini: 66

closes #3467
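## Example

For illustration, a minimal sketch of the new flow from a provider's point of view. Everything named here is hypothetical (ExampleAdapter, its endpoint, and the key are made up; in a real adapter, `__provider_id__` is injected by the resolver and the key/URL come from config) -- only `get_api_key` and `get_base_url` need to be supplied:

```
# Illustrative sketch only: ExampleAdapter, its base URL, and the API key
# are hypothetical stand-ins for a real OpenAIMixin-based adapter.
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class ExampleAdapter(OpenAIMixin):
    __provider_id__ = "example"  # normally added by instantiate_provider in resolver.py

    def get_api_key(self) -> str:
        return "sk-example"  # a real adapter reads this from its config

    def get_base_url(self) -> str:
        return "https://api.example.com/v1"  # any OpenAI-compatible endpoint


async def demo() -> None:
    adapter = ExampleAdapter()
    # The first call issues GET {base_url}/models and fills adapter._model_cache ...
    models = await adapter.list_models()
    if models:
        # ... so this check is answered from the cache, with no second request.
        assert await adapter.check_model_availability(models[0].identifier)
```

Since list_models() rebuilds the cache on each call, repeated listings pick up models newly added on the provider side.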
""" - try: - # Direct model lookup - returns model or raises NotFoundError - await self.client.models.retrieve(model) - return True - except openai.NotFoundError: - # Model doesn't exist - this is expected for unavailable models - pass - except Exception as e: - # All other errors (auth, rate limit, network, etc.) - logger.warning(f"Failed to check model availability for {model}: {e}") + if not self._model_cache: + await self.list_models() - return False + return model in self._model_cache diff --git a/tests/unit/providers/inference/test_openai_base_url_config.py b/tests/unit/providers/inference/test_openai_base_url_config.py index 150f6210b..903772f0c 100644 --- a/tests/unit/providers/inference/test_openai_base_url_config.py +++ b/tests/unit/providers/inference/test_openai_base_url_config.py @@ -5,7 +5,7 @@ # the root directory of this source tree. import os -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch from llama_stack.core.stack import replace_env_vars from llama_stack.providers.remote.inference.openai.config import OpenAIConfig @@ -80,11 +80,22 @@ class TestOpenAIBaseURLConfig: # Mock the get_api_key method adapter.get_api_key = MagicMock(return_value="test-key") - # Mock the AsyncOpenAI client and its models.retrieve method + # Mock a model object that will be returned by models.list() + mock_model = MagicMock() + mock_model.id = "gpt-4" + + # Create an async iterator that yields our mock model + async def mock_async_iterator(): + yield mock_model + + # Mock the AsyncOpenAI client and its models.list method mock_client = MagicMock() - mock_client.models.retrieve = AsyncMock(return_value=MagicMock()) + mock_client.models.list = MagicMock(return_value=mock_async_iterator()) mock_openai_class.return_value = mock_client + # Set the __provider_id__ attribute that's expected by list_models + adapter.__provider_id__ = "openai" + # Call check_model_availability and verify it returns True assert await adapter.check_model_availability("gpt-4") @@ -94,8 +105,8 @@ class TestOpenAIBaseURLConfig: base_url=custom_url, ) - # Verify the method was called and returned True - mock_client.models.retrieve.assert_called_once_with("gpt-4") + # Verify the models.list method was called + mock_client.models.list.assert_called_once() @patch.dict(os.environ, {"OPENAI_BASE_URL": "https://proxy.openai.com/v1"}) @patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI") @@ -110,11 +121,22 @@ class TestOpenAIBaseURLConfig: # Mock the get_api_key method adapter.get_api_key = MagicMock(return_value="test-key") - # Mock the AsyncOpenAI client + # Mock a model object that will be returned by models.list() + mock_model = MagicMock() + mock_model.id = "gpt-4" + + # Create an async iterator that yields our mock model + async def mock_async_iterator(): + yield mock_model + + # Mock the AsyncOpenAI client and its models.list method mock_client = MagicMock() - mock_client.models.retrieve = AsyncMock(return_value=MagicMock()) + mock_client.models.list = MagicMock(return_value=mock_async_iterator()) mock_openai_class.return_value = mock_client + # Set the __provider_id__ attribute that's expected by list_models + adapter.__provider_id__ = "openai" + # Call check_model_availability and verify it returns True assert await adapter.check_model_availability("gpt-4") diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py new file mode 100644 index 000000000..93f82da19 --- /dev/null +++ 
diff --git a/tests/unit/providers/utils/inference/test_openai_mixin.py b/tests/unit/providers/utils/inference/test_openai_mixin.py
new file mode 100644
index 000000000..93f82da19
--- /dev/null
+++ b/tests/unit/providers/utils/inference/test_openai_mixin.py
@@ -0,0 +1,183 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from unittest.mock import MagicMock, PropertyMock, patch
+
+import pytest
+
+from llama_stack.apis.inference import Model
+from llama_stack.apis.models import ModelType
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
+
+
+# Test implementation of OpenAIMixin for testing purposes
+class OpenAIMixinImpl(OpenAIMixin):
+    def __init__(self):
+        self.__provider_id__ = "test-provider"
+
+    def get_api_key(self) -> str:
+        raise NotImplementedError("This method should be mocked in tests")
+
+    def get_base_url(self) -> str:
+        raise NotImplementedError("This method should be mocked in tests")
+
+
+@pytest.fixture
+def mixin():
+    """Create a test instance of OpenAIMixin"""
+    return OpenAIMixinImpl()
+
+
+@pytest.fixture
+def mock_models():
+    """Create multiple mock OpenAI model objects"""
+    models = [MagicMock(id=id) for id in ["some-mock-model-id", "another-mock-model-id", "final-mock-model-id"]]
+    return models
+
+
+@pytest.fixture
+def mock_client_with_models(mock_models):
+    """Create a mock client with models.list() set up to return mock_models"""
+    mock_client = MagicMock()
+
+    async def mock_models_list():
+        for model in mock_models:
+            yield model
+
+    mock_client.models.list.return_value = mock_models_list()
+    return mock_client
+
+
+@pytest.fixture
+def mock_client_with_empty_models():
+    """Create a mock client with models.list() set up to return empty list"""
+    mock_client = MagicMock()
+
+    async def mock_empty_models_list():
+        return
+        yield  # Make it an async generator but don't yield anything
+
+    mock_client.models.list.return_value = mock_empty_models_list()
+    return mock_client
+
+
+@pytest.fixture
+def mock_client_with_exception():
+    """Create a mock client with models.list() set up to raise an exception"""
+    mock_client = MagicMock()
+    mock_client.models.list.side_effect = Exception("API Error")
+    return mock_client
+
+
+@pytest.fixture
+def mock_client_context():
+    """Fixture that provides a context manager for mocking the OpenAI client"""
+
+    def _mock_client_context(mixin, mock_client):
+        return patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client)
+
+    return _mock_client_context
+
+
+class TestOpenAIMixinListModels:
+    """Test cases for the list_models method"""
+
+    async def test_list_models_success(self, mixin, mock_client_with_models, mock_client_context):
+        """Test successful model listing"""
+        assert len(mixin._model_cache) == 0
+
+        with mock_client_context(mixin, mock_client_with_models):
+            result = await mixin.list_models()
+
+        assert result is not None
+        assert len(result) == 3
+
+        model_ids = [model.identifier for model in result]
+        assert "some-mock-model-id" in model_ids
+        assert "another-mock-model-id" in model_ids
+        assert "final-mock-model-id" in model_ids
+
+        for model in result:
+            assert model.provider_id == "test-provider"
+            assert model.model_type == ModelType.llm
+            assert model.provider_resource_id == model.identifier
+
+        assert len(mixin._model_cache) == 3
+        for model_id in ["some-mock-model-id", "another-mock-model-id", "final-mock-model-id"]:
+            assert model_id in mixin._model_cache
+            cached_model = mixin._model_cache[model_id]
+            assert cached_model.identifier == model_id
+            assert cached_model.provider_resource_id == model_id
+
+    async def test_list_models_empty_response(self, mixin, mock_client_with_empty_models, mock_client_context):
+        """Test handling of empty model list"""
+        with mock_client_context(mixin, mock_client_with_empty_models):
+            result = await mixin.list_models()
+
+        assert result is not None
+        assert len(result) == 0
+        assert len(mixin._model_cache) == 0
+
+
+class TestOpenAIMixinCheckModelAvailability:
+    """Test cases for the check_model_availability method"""
+
+    async def test_check_model_availability_with_cache(self, mixin, mock_client_with_models, mock_client_context):
+        """Test model availability check when cache is populated"""
+        with mock_client_context(mixin, mock_client_with_models):
+            mock_client_with_models.models.list.assert_not_called()
+            await mixin.list_models()
+            mock_client_with_models.models.list.assert_called_once()
+
+            assert await mixin.check_model_availability("some-mock-model-id")
+            assert await mixin.check_model_availability("another-mock-model-id")
+            assert await mixin.check_model_availability("final-mock-model-id")
+            assert not await mixin.check_model_availability("non-existent-model")
+            mock_client_with_models.models.list.assert_called_once()
+
+    async def test_check_model_availability_without_cache(self, mixin, mock_client_with_models, mock_client_context):
+        """Test model availability check when cache is empty (calls list_models)"""
+        assert len(mixin._model_cache) == 0
+
+        with mock_client_context(mixin, mock_client_with_models):
+            mock_client_with_models.models.list.assert_not_called()
+            assert await mixin.check_model_availability("some-mock-model-id")
+            mock_client_with_models.models.list.assert_called_once()

+        assert len(mixin._model_cache) == 3
+        assert "some-mock-model-id" in mixin._model_cache
+
+    async def test_check_model_availability_model_not_found(self, mixin, mock_client_with_models, mock_client_context):
+        """Test model availability check for non-existent model"""
+        with mock_client_context(mixin, mock_client_with_models):
+            mock_client_with_models.models.list.assert_not_called()
+            assert not await mixin.check_model_availability("non-existent-model")
+            mock_client_with_models.models.list.assert_called_once()
+
+        assert len(mixin._model_cache) == 3
+
+
+class TestOpenAIMixinCacheBehavior:
+    """Test cases for cache behavior and edge cases"""
+
+    async def test_cache_overwrites_on_list_models_call(self, mixin, mock_client_with_models, mock_client_context):
+        """Test that calling list_models overwrites existing cache"""
+        initial_model = Model(
+            provider_id="test-provider",
+            provider_resource_id="old-model",
+            identifier="old-model",
+            model_type=ModelType.llm,
+        )
+        mixin._model_cache = {"old-model": initial_model}
+
+        with mock_client_context(mixin, mock_client_with_models):
+            await mixin.list_models()
+
+        assert len(mixin._model_cache) == 3
+        assert "old-model" not in mixin._model_cache
+        assert "some-mock-model-id" in mixin._model_cache
+        assert "another-mock-model-id" in mixin._model_cache
+        assert "final-mock-model-id" in mixin._model_cache