mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Merge branch 'main' into hide-non-openai-inference-apis
This commit is contained in:
commit
0e78cd5383
33 changed files with 2394 additions and 1723 deletions
77
tests/integration/inference/test_openai_vision_inference.py
Normal file
77
tests/integration/inference/test_openai_vision_inference.py
Normal file
|
@ -0,0 +1,77 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import base64
|
||||
import pathlib
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def image_path():
|
||||
return pathlib.Path(__file__).parent / "dog.png"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base64_image_data(image_path):
|
||||
return base64.b64encode(image_path.read_bytes()).decode("utf-8")
|
||||
|
||||
|
||||
async def test_openai_chat_completion_image_url(openai_client, vision_model_id):
|
||||
message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/integration/inference/dog.png"
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Describe what is in this image.",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
response = openai_client.chat.completions.create(
|
||||
model=vision_model_id,
|
||||
messages=[message],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
message_content = response.choices[0].message.content.lower().strip()
|
||||
assert len(message_content) > 0
|
||||
assert any(expected in message_content for expected in {"dog", "puppy", "pup"})
|
||||
|
||||
|
||||
async def test_openai_chat_completion_image_data(openai_client, vision_model_id, base64_image_data):
|
||||
message = {
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/png;base64,{base64_image_data}",
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Describe what is in this image.",
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
response = openai_client.chat.completions.create(
|
||||
model=vision_model_id,
|
||||
messages=[message],
|
||||
stream=False,
|
||||
)
|
||||
|
||||
message_content = response.choices[0].message.content.lower().strip()
|
||||
assert len(message_content) > 0
|
||||
assert any(expected in message_content for expected in {"dog", "puppy", "pup"})
|
|
@ -10,6 +10,7 @@ from unittest.mock import AsyncMock
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.common.content_types import URL
|
||||
from llama_stack.apis.common.type_system import NumberType
|
||||
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
||||
from llama_stack.apis.datatypes import Api
|
||||
|
@ -645,3 +646,25 @@ async def test_models_source_interaction_cleanup_provider_models(cached_disk_dis
|
|||
|
||||
# Cleanup
|
||||
await table.shutdown()
|
||||
|
||||
|
||||
async def test_tool_groups_routing_table_exception_handling(cached_disk_dist_registry):
|
||||
"""Test that the tool group routing table handles exceptions when listing tools, like if an MCP server is unreachable."""
|
||||
|
||||
exception_throwing_tool_groups_impl = ToolGroupsImpl()
|
||||
exception_throwing_tool_groups_impl.list_runtime_tools = AsyncMock(side_effect=Exception("Test exception"))
|
||||
|
||||
table = ToolGroupsRoutingTable(
|
||||
{"test_provider": exception_throwing_tool_groups_impl}, cached_disk_dist_registry, {}
|
||||
)
|
||||
await table.initialize()
|
||||
|
||||
await table.register_tool_group(
|
||||
toolgroup_id="test-toolgroup-exceptions",
|
||||
provider_id="test_provider",
|
||||
mcp_endpoint=URL(uri="http://localhost:8479/foo/bar"),
|
||||
)
|
||||
|
||||
tools = await table.list_tools(toolgroup_id="test-toolgroup-exceptions")
|
||||
|
||||
assert len(tools.data) == 0
|
||||
|
|
|
@ -4,11 +4,11 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import MagicMock, PropertyMock, patch
|
||||
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import Model
|
||||
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
@ -43,8 +43,17 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixin):
|
|||
|
||||
@pytest.fixture
|
||||
def mixin():
|
||||
"""Create a test instance of OpenAIMixin"""
|
||||
return OpenAIMixinImpl()
|
||||
"""Create a test instance of OpenAIMixin with mocked model_store"""
|
||||
mixin_instance = OpenAIMixinImpl()
|
||||
|
||||
# just enough to satisfy _get_provider_model_id calls
|
||||
mock_model_store = MagicMock()
|
||||
mock_model = MagicMock()
|
||||
mock_model.provider_resource_id = "test-provider-resource-id"
|
||||
mock_model_store.get_model = AsyncMock(return_value=mock_model)
|
||||
mixin_instance.model_store = mock_model_store
|
||||
|
||||
return mixin_instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -205,6 +214,74 @@ class TestOpenAIMixinCacheBehavior:
|
|||
assert "final-mock-model-id" in mixin._model_cache
|
||||
|
||||
|
||||
class TestOpenAIMixinImagePreprocessing:
|
||||
"""Test cases for image preprocessing functionality"""
|
||||
|
||||
async def test_openai_chat_completion_with_image_preprocessing_enabled(self, mixin):
|
||||
"""Test that image URLs are converted to base64 when download_images is True"""
|
||||
mixin.download_images = True
|
||||
|
||||
message = OpenAIUserMessageParam(
|
||||
role="user",
|
||||
content=[
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
|
||||
],
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
|
||||
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
|
||||
mock_localize.return_value = (b"fake_image_data", "jpeg")
|
||||
|
||||
await mixin.openai_chat_completion(model="test-model", messages=[message])
|
||||
|
||||
mock_localize.assert_called_once_with("http://example.com/image.jpg")
|
||||
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
call_args = mock_client.chat.completions.create.call_args
|
||||
processed_messages = call_args[1]["messages"]
|
||||
assert len(processed_messages) == 1
|
||||
content = processed_messages[0]["content"]
|
||||
assert len(content) == 2
|
||||
assert content[0]["type"] == "text"
|
||||
assert content[1]["type"] == "image_url"
|
||||
assert content[1]["image_url"]["url"] == "data:image/jpeg;base64,ZmFrZV9pbWFnZV9kYXRh"
|
||||
|
||||
async def test_openai_chat_completion_with_image_preprocessing_disabled(self, mixin):
|
||||
"""Test that image URLs are not modified when download_images is False"""
|
||||
mixin.download_images = False # explicitly set to False
|
||||
|
||||
message = OpenAIUserMessageParam(
|
||||
role="user",
|
||||
content=[
|
||||
{"type": "text", "text": "What's in this image?"},
|
||||
{"type": "image_url", "image_url": {"url": "http://example.com/image.jpg"}},
|
||||
],
|
||||
)
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
|
||||
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
|
||||
await mixin.openai_chat_completion(model="test-model", messages=[message])
|
||||
|
||||
mock_localize.assert_not_called()
|
||||
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
call_args = mock_client.chat.completions.create.call_args
|
||||
processed_messages = call_args[1]["messages"]
|
||||
assert len(processed_messages) == 1
|
||||
content = processed_messages[0]["content"]
|
||||
assert len(content) == 2
|
||||
assert content[1]["image_url"]["url"] == "http://example.com/image.jpg"
|
||||
|
||||
|
||||
class TestOpenAIMixinEmbeddingModelMetadata:
|
||||
"""Test cases for embedding_model_metadata attribute functionality"""
|
||||
|
||||
|
|
|
@ -129,7 +129,7 @@ async def test_duplicate_provider_registration(cached_disk_dist_registry):
|
|||
|
||||
result = await cached_disk_dist_registry.get("vector_db", "test_vector_db_2")
|
||||
assert result is not None
|
||||
assert result.embedding_model == duplicate_vector_db.embedding_model # Original values preserved
|
||||
assert result.embedding_model == original_vector_db.embedding_model # Original values preserved
|
||||
|
||||
|
||||
async def test_get_all_objects(cached_disk_dist_registry):
|
||||
|
@ -174,14 +174,10 @@ async def test_parse_registry_values_error_handling(sqlite_kvstore):
|
|||
)
|
||||
|
||||
await sqlite_kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"),
|
||||
valid_db.model_dump_json(),
|
||||
KEY_FORMAT.format(type="vector_db", identifier="valid_vector_db"), valid_db.model_dump_json()
|
||||
)
|
||||
|
||||
await sqlite_kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"),
|
||||
"{not valid json",
|
||||
)
|
||||
await sqlite_kvstore.set(KEY_FORMAT.format(type="vector_db", identifier="corrupted_json"), "{not valid json")
|
||||
|
||||
await sqlite_kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="missing_fields"),
|
||||
|
@ -216,8 +212,7 @@ async def test_cached_registry_error_handling(sqlite_kvstore):
|
|||
)
|
||||
|
||||
await sqlite_kvstore.set(
|
||||
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"),
|
||||
valid_db.model_dump_json(),
|
||||
KEY_FORMAT.format(type="vector_db", identifier="valid_cached_db"), valid_db.model_dump_json()
|
||||
)
|
||||
|
||||
await sqlite_kvstore.set(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue