Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-11 05:38:38 +00:00)
feat: include all models from provider's /v1/models
This replaces the static model listing for any provider using OpenAIMixin.

Tested with:
- new unit tests
- manual runs against llama-api, openai, groq, and gemini:

```
for provider in llama-openai-compat openai groq gemini; do
  uv run llama stack build --image-type venv --providers inference=remote::$provider --run &
  uv run --with llama-stack-client llama-stack-client models list | grep Total
done
```

Results (17 Sep 2025):
- llama-api: 4
- openai: 86
- groq: 21
- gemini: 66
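The behavior being tested can be sketched as follows. This is an illustrative reduction, not the actual OpenAIMixin code: the class name, constructor arguments, and return shape are assumptions; only the use of the OpenAI SDK's async `models.list()` paginator is taken from the diff below.

```python
# Illustrative sketch (assumed names, not the real OpenAIMixin): list models
# dynamically from a provider's /v1/models instead of a static table.
from openai import AsyncOpenAI


class DynamicModelLister:
    def __init__(self, api_key: str, base_url: str, provider_id: str):
        self.__provider_id__ = provider_id
        # Any OpenAI-compatible provider works: point base_url at its /v1 root.
        self._client = AsyncOpenAI(api_key=api_key, base_url=base_url)

    async def list_models(self) -> list[str]:
        # models.list() returns an async paginator; `async for` walks every
        # page, so the result covers everything the provider advertises.
        return [model.id async for model in self._client.models.list()]

    async def check_model_availability(self, model_id: str) -> bool:
        # Availability becomes a membership test against the live listing.
        return model_id in await self.list_models()
```

Because `models.list()` is called synchronously and then iterated with `async for`, the updated unit tests below mock it with a plain `MagicMock` wrapping an async generator rather than an `AsyncMock`.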
parent 9acf49753e · commit c35606facb
3 changed files with 242 additions and 20 deletions (the unit-test changes are excerpted below)
The diff to the `TestOpenAIBaseURLConfig` unit tests, where the `models.retrieve` mocks are replaced with an async-iterator mock of `models.list`:

```diff
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 import os
-from unittest.mock import AsyncMock, MagicMock, patch
+from unittest.mock import MagicMock, patch
 
 from llama_stack.core.stack import replace_env_vars
 from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
@@ -80,11 +80,22 @@ class TestOpenAIBaseURLConfig:
         # Mock the get_api_key method
         adapter.get_api_key = MagicMock(return_value="test-key")
 
-        # Mock the AsyncOpenAI client and its models.retrieve method
+        # Mock a model object that will be returned by models.list()
+        mock_model = MagicMock()
+        mock_model.id = "gpt-4"
+
+        # Create an async iterator that yields our mock model
+        async def mock_async_iterator():
+            yield mock_model
+
+        # Mock the AsyncOpenAI client and its models.list method
         mock_client = MagicMock()
-        mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
+        mock_client.models.list = MagicMock(return_value=mock_async_iterator())
         mock_openai_class.return_value = mock_client
 
+        # Set the __provider_id__ attribute that's expected by list_models
+        adapter.__provider_id__ = "openai"
+
         # Call check_model_availability and verify it returns True
         assert await adapter.check_model_availability("gpt-4")
 
@@ -94,8 +105,8 @@ class TestOpenAIBaseURLConfig:
             base_url=custom_url,
         )
 
-        # Verify the method was called and returned True
-        mock_client.models.retrieve.assert_called_once_with("gpt-4")
+        # Verify the models.list method was called
+        mock_client.models.list.assert_called_once()
 
     @patch.dict(os.environ, {"OPENAI_BASE_URL": "https://proxy.openai.com/v1"})
     @patch("llama_stack.providers.utils.inference.openai_mixin.AsyncOpenAI")
@@ -110,11 +121,22 @@ class TestOpenAIBaseURLConfig:
         # Mock the get_api_key method
         adapter.get_api_key = MagicMock(return_value="test-key")
 
-        # Mock the AsyncOpenAI client
+        # Mock a model object that will be returned by models.list()
+        mock_model = MagicMock()
+        mock_model.id = "gpt-4"
+
+        # Create an async iterator that yields our mock model
+        async def mock_async_iterator():
+            yield mock_model
+
+        # Mock the AsyncOpenAI client and its models.list method
         mock_client = MagicMock()
-        mock_client.models.retrieve = AsyncMock(return_value=MagicMock())
+        mock_client.models.list = MagicMock(return_value=mock_async_iterator())
         mock_openai_class.return_value = mock_client
 
+        # Set the __provider_id__ attribute that's expected by list_models
+        adapter.__provider_id__ = "openai"
+
         # Call check_model_availability and verify it returns True
         assert await adapter.check_model_availability("gpt-4")
```
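For reference, the mocking pattern above can be exercised standalone. This is a minimal sketch: `check_model_availability` here is a hypothetical stand-in for the adapter method, reduced to just the behavior the mocks exercise.

```python
# Self-contained illustration of the async-iterator mocking pattern above.
# check_model_availability is a hypothetical stand-in, not the real adapter.
import asyncio
from unittest.mock import MagicMock


async def check_model_availability(client, model_id: str) -> bool:
    # Mirrors the tested behavior: walk /v1/models and match on id.
    async for model in client.models.list():
        if model.id == model_id:
            return True
    return False


def test_model_availability_via_models_list():
    # A mock model object, as yielded by iterating models.list()
    mock_model = MagicMock()
    mock_model.id = "gpt-4"

    # models.list() is a synchronous call that returns an async iterator,
    # so a plain MagicMock wrapping an async generator has the right shape;
    # an AsyncMock would make the call itself awaitable, which is wrong.
    async def mock_async_iterator():
        yield mock_model

    mock_client = MagicMock()
    mock_client.models.list = MagicMock(return_value=mock_async_iterator())

    assert asyncio.run(check_model_availability(mock_client, "gpt-4"))
    mock_client.models.list.assert_called_once()


if __name__ == "__main__":
    test_model_availability_via_models_list()
    print("ok")
```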