Merge branch 'main' into hide-non-openai-inference-apis

This commit is contained in:
Matthew Farrellee 2025-09-26 17:48:30 -04:00
commit 0e78cd5383
33 changed files with 2394 additions and 1723 deletions

View file

@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import pathlib
import pytest
@pytest.fixture
def image_path():
return pathlib.Path(__file__).parent / "dog.png"
@pytest.fixture
def base64_image_data(image_path):
return base64.b64encode(image_path.read_bytes()).decode("utf-8")
async def test_openai_chat_completion_image_url(openai_client, vision_model_id):
message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/integration/inference/dog.png"
},
},
{
"type": "text",
"text": "Describe what is in this image.",
},
],
}
response = openai_client.chat.completions.create(
model=vision_model_id,
messages=[message],
stream=False,
)
message_content = response.choices[0].message.content.lower().strip()
assert len(message_content) > 0
assert any(expected in message_content for expected in {"dog", "puppy", "pup"})
async def test_openai_chat_completion_image_data(openai_client, vision_model_id, base64_image_data):
message = {
"role": "user",
"content": [
{
"type": "image_url",
"image_url": {
"url": f"data:image/png;base64,{base64_image_data}",
},
},
{
"type": "text",
"text": "Describe what is in this image.",
},
],
}
response = openai_client.chat.completions.create(
model=vision_model_id,
messages=[message],
stream=False,
)
message_content = response.choices[0].message.content.lower().strip()
assert len(message_content) > 0
assert any(expected in message_content for expected in {"dog", "puppy", "pup"})

View file

@ -10,6 +10,7 @@ from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api
@ -645,3 +646,25 @@ async def test_models_source_interaction_cleanup_provider_models(cached_disk_dis
# Cleanup
await table.shutdown()
async def test_tool_groups_routing_table_exception_handling(cached_disk_dist_registry):
"""Test that the tool group routing table handles exceptions when listing tools, like if an MCP server is unreachable."""
exception_throwing_tool_groups_impl = ToolGroupsImpl()
exception_throwing_tool_groups_impl.list_runtime_tools = AsyncMock(side_effect=Exception("Test exception"))
table = ToolGroupsRoutingTable(
{"test_provider": exception_throwing_tool_groups_impl}, cached_disk_dist_registry, {}
)
await table.initialize()
await table.register_tool_group(
toolgroup_id="test-toolgroup-exceptions",
provider_id="test_provider",
mcp_endpoint=URL(uri="http://localhost:8479/foo/bar"),
)
tools = await table.list_tools(toolgroup_id="test-toolgroup-exceptions")
assert len(tools.data) == 0

View file

@ -4,11 +4,11 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import MagicMock, PropertyMock, patch
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import pytest
from llama_stack.apis.inference import Model
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
@ -43,8 +43,17 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin):
@pytest.fixture
def mixin():
"""Create a test instance of OpenAIMixin"""
return OpenAIMixinImpl()
"""Create a test instance of OpenAIMixin with mocked model_store"""
mixin_instance = OpenAIMixinImpl()
# just enough to satisfy _get_provider_model_id calls
mock_model_store = MagicMock()
mock_model = MagicMock()
mock_model.provider_resource_id = "test-provider-resource-id"
mock_model_store.get_model = AsyncMock(return_value=mock_model)
mixin_instance.model_store = mock_model_store
return mixin_instance
@pytest.fixture
@ -205,6 +214,74 @@ class TestOpenAIMixinCacheBehavior:
assert "final-mock-model-id" in mixin._model_cache
class TestOpenAIMixinImagePreprocessing:
"""Test cases for image preprocessing functionality"""
async def test_openai_chat_completion_with_image_preprocessing_enabled(self, mixin):
"""Test that image URLs are converted to base64 when download_images is True"""
mixin.download_images = True
message = OpenAIUserMessageParam(
role="user",
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
],
)
mock_client = MagicMock()
mock_response = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
mock_localize.return_value = (b"fake_image_data", "jpeg")
await mixin.openai_chat_completion(model="test-model", messages=[message])
mock_localize.assert_called_once_with("http://example.com/image.jpg")
mock_client.chat.completions.create.assert_called_once()
call_args = mock_client.chat.completions.create.call_args
processed_messages = call_args[1]["messages"]
assert len(processed_messages) == 1
content = processed_messages[0]["content"]
assert len(content) == 2
assert content[0]["type"] == "text"
assert content[1]["type"] == "image_url"
assert content[1]["image_url"]["url"] == ""
async def test_openai_chat_completion_with_image_preprocessing_disabled(self, mixin):
"""Test that image URLs are not modified when download_images is False"""
mixin.download_images = False # explicitly set to False
message = OpenAIUserMessageParam(
role="user",
content=[
{"type": "text", "text": "What's in this image?"},
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
],
)
mock_client = MagicMock()
mock_response = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
await mixin.openai_chat_completion(model="test-model", messages=[message])
mock_localize.assert_not_called()
mock_client.chat.completions.create.assert_called_once()
call_args = mock_client.chat.completions.create.call_args
processed_messages = call_args[1]["messages"]
assert len(processed_messages) == 1
content = processed_messages[0]["content"]
assert len(content) == 2
assert content[1]["image_url"]["url"] == "http://example.com/image.jpg"
class TestOpenAIMixinEmbeddingModelMetadata:
"""Test cases for embedding_model_metadata attribute functionality"""

View file

@ -129,7 +129,7 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
assert result is not None
assert result.embedding_model == duplicate_vector_db.embedding_model # Original values preserved
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
async def test_get_all_objects(cached_disk_dist_registry):
@ -174,14 +174,10 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):
)
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"),
valid_db.model_dump_json(),
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()
)
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"),
"{not valid json",
)
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
@ -216,8 +212,7 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
)
await sqlite_kvstore.set(
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"),
valid_db.model_dump_json(),
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()
)
await sqlite_kvstore.set(