mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 04:00:42 +00:00
Merge origin/main into add-missing-provider-data-impls
Resolved conflicts in: - benchmarking/k8s-benchmark/stack_run_config.yaml (accepted new storage schema) - llama_stack/providers/remote/inference/cerebras/cerebras.py (kept provider data support) - llama_stack/providers/remote/inference/cerebras/config.py (kept provider data support) - llama_stack/providers/remote/inference/nvidia/config.py (kept provider data support) - llama_stack/providers/remote/inference/runpod/config.py (merged imports) - pyproject.toml (kept databricks-sdk dependency)
This commit is contained in:
commit
9eb9a37ee4
1880 changed files with 804868 additions and 70533 deletions
|
|
@ -15,6 +15,7 @@ from llama_stack.apis.agents import (
|
|||
AgentCreateResponse,
|
||||
)
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolRuntime
|
||||
|
|
@ -25,6 +26,20 @@ from llama_stack.providers.inline.agents.meta_reference.config import MetaRefere
|
|||
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_backends(tmp_path):
|
||||
"""Register KV and SQL store backends for testing."""
|
||||
from llama_stack.core.storage.datatypes import SqliteKVStoreConfig, SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
kv_path = str(tmp_path / "test_kv.db")
|
||||
sql_path = str(tmp_path / "test_sql.db")
|
||||
|
||||
register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path=kv_path)})
|
||||
register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path=sql_path)})
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_apis():
|
||||
return {
|
||||
|
|
@ -33,20 +48,26 @@ def mock_apis():
|
|||
"safety_api": AsyncMock(spec=Safety),
|
||||
"tool_runtime_api": AsyncMock(spec=ToolRuntime),
|
||||
"tool_groups_api": AsyncMock(spec=ToolGroups),
|
||||
"conversations_api": AsyncMock(spec=Conversations),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def config(tmp_path):
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
|
||||
from llama_stack.providers.inline.agents.meta_reference.config import AgentPersistenceConfig
|
||||
|
||||
return MetaReferenceAgentsImplConfig(
|
||||
persistence_store={
|
||||
"type": "sqlite",
|
||||
"db_path": str(tmp_path / "test.db"),
|
||||
},
|
||||
responses_store={
|
||||
"type": "sqlite",
|
||||
"db_path": str(tmp_path / "test.db"),
|
||||
},
|
||||
persistence=AgentPersistenceConfig(
|
||||
agent_state=KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="agents",
|
||||
),
|
||||
responses=ResponsesStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="responses",
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -59,7 +80,8 @@ async def agents_impl(config, mock_apis):
|
|||
mock_apis["safety_api"],
|
||||
mock_apis["tool_runtime_api"],
|
||||
mock_apis["tool_groups_api"],
|
||||
{},
|
||||
mock_apis["conversations_api"],
|
||||
[],
|
||||
)
|
||||
await impl.initialize()
|
||||
yield impl
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion_chunk import (
|
||||
|
|
@ -20,9 +20,11 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
ListOpenAIResponseInputItem,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
OpenAIResponseText,
|
||||
|
|
@ -32,15 +34,16 @@ from llama_stack.apis.agents.openai_responses import (
|
|||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIJSONSchema,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.apis.tools.tools import ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.core.datatypes import ResponsesStoreConfig
|
||||
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
|
|
@ -48,7 +51,7 @@ from llama_stack.providers.utils.responses.responses_store import (
|
|||
ResponsesStore,
|
||||
_OpenAIResponseObjectWithInputAndMessages,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture
|
||||
|
||||
|
||||
|
|
@ -82,9 +85,28 @@ def mock_vector_io_api():
|
|||
return vector_io_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversations_api():
|
||||
"""Mock conversations API for testing."""
|
||||
mock_api = AsyncMock()
|
||||
return mock_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_safety_api():
|
||||
safety_api = AsyncMock()
|
||||
return safety_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_responses_impl(
|
||||
mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store, mock_vector_io_api
|
||||
mock_inference_api,
|
||||
mock_tool_groups_api,
|
||||
mock_tool_runtime_api,
|
||||
mock_responses_store,
|
||||
mock_vector_io_api,
|
||||
mock_safety_api,
|
||||
mock_conversations_api,
|
||||
):
|
||||
return OpenAIResponsesImpl(
|
||||
inference_api=mock_inference_api,
|
||||
|
|
@ -92,6 +114,8 @@ def openai_responses_impl(
|
|||
tool_runtime_api=mock_tool_runtime_api,
|
||||
responses_store=mock_responses_store,
|
||||
vector_io_api=mock_vector_io_api,
|
||||
safety_api=mock_safety_api,
|
||||
conversations_api=mock_conversations_api,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -147,18 +171,24 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
|||
chunks = [chunk async for chunk in result]
|
||||
|
||||
mock_inference_api.openai_chat_completion.assert_called_once_with(
|
||||
model=model,
|
||||
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
|
||||
response_format=None,
|
||||
tools=None,
|
||||
stream=True,
|
||||
temperature=0.1,
|
||||
OpenAIChatCompletionRequestWithExtraBody(
|
||||
model=model,
|
||||
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
|
||||
response_format=None,
|
||||
tools=None,
|
||||
stream=True,
|
||||
temperature=0.1,
|
||||
stream_options={
|
||||
"include_usage": True,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Should have content part events for text streaming
|
||||
# Expected: response.created, content_part.added, output_text.delta, content_part.done, response.completed
|
||||
assert len(chunks) >= 4
|
||||
# Expected: response.created, response.in_progress, content_part.added, output_text.delta, content_part.done, response.completed
|
||||
assert len(chunks) >= 5
|
||||
assert chunks[0].type == "response.created"
|
||||
assert any(chunk.type == "response.in_progress" for chunk in chunks)
|
||||
|
||||
# Check for content part events
|
||||
content_part_added_events = [c for c in chunks if c.type == "response.content_part.added"]
|
||||
|
|
@ -169,6 +199,14 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
|||
assert len(content_part_done_events) >= 1, "Should have content_part.done event for text"
|
||||
assert len(text_delta_events) >= 1, "Should have text delta events"
|
||||
|
||||
added_event = content_part_added_events[0]
|
||||
done_event = content_part_done_events[0]
|
||||
assert added_event.content_index == 0
|
||||
assert done_event.content_index == 0
|
||||
assert added_event.output_index == done_event.output_index == 0
|
||||
assert added_event.item_id == done_event.item_id
|
||||
assert added_event.response_id == done_event.response_id
|
||||
|
||||
# Verify final event is completion
|
||||
assert chunks[-1].type == "response.completed"
|
||||
|
||||
|
|
@ -177,6 +215,8 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
|
|||
assert final_response.model == model
|
||||
assert len(final_response.output) == 1
|
||||
assert isinstance(final_response.output[0], OpenAIResponseMessage)
|
||||
assert final_response.output[0].id == added_event.item_id
|
||||
assert final_response.id == added_event.response_id
|
||||
|
||||
openai_responses_impl.responses_store.store_response_object.assert_called_once()
|
||||
assert final_response.output[0].content[0].text == "Dublin"
|
||||
|
|
@ -228,13 +268,15 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
|
|||
|
||||
# Verify
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == "What is the capital of Ireland?"
|
||||
assert first_call.kwargs["tools"] is not None
|
||||
assert first_call.kwargs["temperature"] == 0.1
|
||||
first_params = first_call.args[0]
|
||||
assert first_params.messages[0].content == "What is the capital of Ireland?"
|
||||
assert first_params.tools is not None
|
||||
assert first_params.temperature == 0.1
|
||||
|
||||
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
|
||||
assert second_call.kwargs["messages"][-1].content == "Dublin"
|
||||
assert second_call.kwargs["temperature"] == 0.1
|
||||
second_params = second_call.args[0]
|
||||
assert second_params.messages[-1].content == "Dublin"
|
||||
assert second_params.temperature == 0.1
|
||||
|
||||
openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search")
|
||||
openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with(
|
||||
|
|
@ -303,36 +345,42 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
|
|||
chunks = [chunk async for chunk in result]
|
||||
|
||||
# Verify event types
|
||||
# Should have: response.created, output_item.added, function_call_arguments.delta,
|
||||
# function_call_arguments.done, output_item.done, response.completed
|
||||
assert len(chunks) == 6
|
||||
# Should have: response.created, response.in_progress, output_item.added,
|
||||
# function_call_arguments.delta, function_call_arguments.done, output_item.done, response.completed
|
||||
assert len(chunks) == 7
|
||||
|
||||
event_types = [chunk.type for chunk in chunks]
|
||||
assert event_types == [
|
||||
"response.created",
|
||||
"response.in_progress",
|
||||
"response.output_item.added",
|
||||
"response.function_call_arguments.delta",
|
||||
"response.function_call_arguments.done",
|
||||
"response.output_item.done",
|
||||
"response.completed",
|
||||
]
|
||||
|
||||
# Verify inference API was called correctly (after iterating over result)
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == input_text
|
||||
assert first_call.kwargs["tools"] is not None
|
||||
assert first_call.kwargs["temperature"] == 0.1
|
||||
first_params = first_call.args[0]
|
||||
assert first_params.messages[0].content == input_text
|
||||
assert first_params.tools is not None
|
||||
assert first_params.temperature == 0.1
|
||||
|
||||
# Check response.created event (should have empty output)
|
||||
assert chunks[0].type == "response.created"
|
||||
assert len(chunks[0].response.output) == 0
|
||||
|
||||
# Check streaming events
|
||||
assert chunks[1].type == "response.output_item.added"
|
||||
assert chunks[2].type == "response.function_call_arguments.delta"
|
||||
assert chunks[3].type == "response.function_call_arguments.done"
|
||||
assert chunks[4].type == "response.output_item.done"
|
||||
|
||||
# Check response.completed event (should have the tool call)
|
||||
assert chunks[5].type == "response.completed"
|
||||
assert len(chunks[5].response.output) == 1
|
||||
assert chunks[5].response.output[0].type == "function_call"
|
||||
assert chunks[5].response.output[0].name == "get_weather"
|
||||
completed_chunk = chunks[-1]
|
||||
assert completed_chunk.type == "response.completed"
|
||||
assert len(completed_chunk.response.output) == 1
|
||||
assert completed_chunk.response.output[0].type == "function_call"
|
||||
assert completed_chunk.response.output[0].name == "get_weather"
|
||||
|
||||
|
||||
async def test_create_openai_response_with_tool_call_function_arguments_none(openai_responses_impl, mock_inference_api):
|
||||
"""Test creating an OpenAI response with a tool call response that has a function that does not accept arguments, or arguments set to None when they are not mandatory."""
|
||||
# Setup
|
||||
"""Test creating an OpenAI response with tool calls that omit arguments."""
|
||||
|
||||
input_text = "What is the time right now?"
|
||||
model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
|
|
@ -359,9 +407,22 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
|
|||
object="chat.completion.chunk",
|
||||
)
|
||||
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
def assert_common_expectations(chunks) -> None:
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
first_params = first_call.args[0]
|
||||
assert first_params.messages[0].content == input_text
|
||||
assert first_params.tools is not None
|
||||
assert first_params.temperature == 0.1
|
||||
assert len(chunks[0].response.output) == 0
|
||||
completed_chunk = chunks[-1]
|
||||
assert completed_chunk.type == "response.completed"
|
||||
assert len(completed_chunk.response.output) == 1
|
||||
assert completed_chunk.response.output[0].type == "function_call"
|
||||
assert completed_chunk.response.output[0].name == "get_current_time"
|
||||
assert completed_chunk.response.output[0].arguments == "{}"
|
||||
|
||||
# Function does not accept arguments
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
result = await openai_responses_impl.create_openai_response(
|
||||
input=input_text,
|
||||
model=model,
|
||||
|
|
@ -369,46 +430,23 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
|
|||
temperature=0.1,
|
||||
tools=[
|
||||
OpenAIResponseInputToolFunction(
|
||||
name="get_current_time",
|
||||
description="Get current time for system's timezone",
|
||||
parameters={},
|
||||
name="get_current_time", description="Get current time for system's timezone", parameters={}
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Check that we got the content from our mocked tool execution result
|
||||
chunks = [chunk async for chunk in result]
|
||||
|
||||
# Verify event types
|
||||
# Should have: response.created, output_item.added, function_call_arguments.delta,
|
||||
# function_call_arguments.done, output_item.done, response.completed
|
||||
assert len(chunks) == 5
|
||||
|
||||
# Verify inference API was called correctly (after iterating over result)
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == input_text
|
||||
assert first_call.kwargs["tools"] is not None
|
||||
assert first_call.kwargs["temperature"] == 0.1
|
||||
|
||||
# Check response.created event (should have empty output)
|
||||
assert chunks[0].type == "response.created"
|
||||
assert len(chunks[0].response.output) == 0
|
||||
|
||||
# Check streaming events
|
||||
assert chunks[1].type == "response.output_item.added"
|
||||
assert chunks[2].type == "response.function_call_arguments.done"
|
||||
assert chunks[3].type == "response.output_item.done"
|
||||
|
||||
# Check response.completed event (should have the tool call with arguments set to "{}")
|
||||
assert chunks[4].type == "response.completed"
|
||||
assert len(chunks[4].response.output) == 1
|
||||
assert chunks[4].response.output[0].type == "function_call"
|
||||
assert chunks[4].response.output[0].name == "get_current_time"
|
||||
assert chunks[4].response.output[0].arguments == "{}"
|
||||
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
assert [chunk.type for chunk in chunks] == [
|
||||
"response.created",
|
||||
"response.in_progress",
|
||||
"response.output_item.added",
|
||||
"response.function_call_arguments.done",
|
||||
"response.output_item.done",
|
||||
"response.completed",
|
||||
]
|
||||
assert_common_expectations(chunks)
|
||||
|
||||
# Function accepts optional arguments
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
result = await openai_responses_impl.create_openai_response(
|
||||
input=input_text,
|
||||
model=model,
|
||||
|
|
@ -418,42 +456,47 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
|
|||
OpenAIResponseInputToolFunction(
|
||||
name="get_current_time",
|
||||
description="Get current time for system's timezone",
|
||||
parameters={
|
||||
"timezone": "string",
|
||||
},
|
||||
parameters={"timezone": "string"},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Check that we got the content from our mocked tool execution result
|
||||
chunks = [chunk async for chunk in result]
|
||||
assert [chunk.type for chunk in chunks] == [
|
||||
"response.created",
|
||||
"response.in_progress",
|
||||
"response.output_item.added",
|
||||
"response.function_call_arguments.done",
|
||||
"response.output_item.done",
|
||||
"response.completed",
|
||||
]
|
||||
assert_common_expectations(chunks)
|
||||
|
||||
# Verify event types
|
||||
# Should have: response.created, output_item.added, function_call_arguments.delta,
|
||||
# function_call_arguments.done, output_item.done, response.completed
|
||||
assert len(chunks) == 5
|
||||
|
||||
# Verify inference API was called correctly (after iterating over result)
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == input_text
|
||||
assert first_call.kwargs["tools"] is not None
|
||||
assert first_call.kwargs["temperature"] == 0.1
|
||||
|
||||
# Check response.created event (should have empty output)
|
||||
assert chunks[0].type == "response.created"
|
||||
assert len(chunks[0].response.output) == 0
|
||||
|
||||
# Check streaming events
|
||||
assert chunks[1].type == "response.output_item.added"
|
||||
assert chunks[2].type == "response.function_call_arguments.done"
|
||||
assert chunks[3].type == "response.output_item.done"
|
||||
|
||||
# Check response.completed event (should have the tool call with arguments set to "{}")
|
||||
assert chunks[4].type == "response.completed"
|
||||
assert len(chunks[4].response.output) == 1
|
||||
assert chunks[4].response.output[0].type == "function_call"
|
||||
assert chunks[4].response.output[0].name == "get_current_time"
|
||||
assert chunks[4].response.output[0].arguments == "{}"
|
||||
# Function accepts optional arguments with additional optional fields
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
result = await openai_responses_impl.create_openai_response(
|
||||
input=input_text,
|
||||
model=model,
|
||||
stream=True,
|
||||
temperature=0.1,
|
||||
tools=[
|
||||
OpenAIResponseInputToolFunction(
|
||||
name="get_current_time",
|
||||
description="Get current time for system's timezone",
|
||||
parameters={"timezone": "string", "location": "string"},
|
||||
)
|
||||
],
|
||||
)
|
||||
chunks = [chunk async for chunk in result]
|
||||
assert [chunk.type for chunk in chunks] == [
|
||||
"response.created",
|
||||
"response.in_progress",
|
||||
"response.output_item.added",
|
||||
"response.function_call_arguments.done",
|
||||
"response.output_item.done",
|
||||
"response.completed",
|
||||
]
|
||||
assert_common_expectations(chunks)
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
|
||||
|
||||
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
|
||||
|
|
@ -485,7 +528,9 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
|
|||
|
||||
# Verify the the correct messages were sent to the inference API i.e.
|
||||
# All of the responses message were convered to the chat completion message objects
|
||||
inference_messages = mock_inference_api.openai_chat_completion.call_args_list[0].kwargs["messages"]
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
params = call_args.args[0]
|
||||
inference_messages = params.messages
|
||||
for i, m in enumerate(input_messages):
|
||||
if isinstance(m.content, str):
|
||||
assert inference_messages[i].content == m.content
|
||||
|
|
@ -653,7 +698,8 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
|
|||
# Verify
|
||||
mock_inference_api.openai_chat_completion.assert_called_once()
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args
|
||||
sent_messages = call_args.kwargs["messages"]
|
||||
params = call_args.args[0]
|
||||
sent_messages = params.messages
|
||||
|
||||
# Check that instructions were prepended as a system message
|
||||
assert len(sent_messages) == 2
|
||||
|
|
@ -691,7 +737,8 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
|
|||
# Verify
|
||||
mock_inference_api.openai_chat_completion.assert_called_once()
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args
|
||||
sent_messages = call_args.kwargs["messages"]
|
||||
params = call_args.args[0]
|
||||
sent_messages = params.messages
|
||||
|
||||
# Check that instructions were prepended as a system message
|
||||
assert len(sent_messages) == 4 # 1 system + 3 input messages
|
||||
|
|
@ -751,7 +798,8 @@ async def test_create_openai_response_with_instructions_and_previous_response(
|
|||
# Verify
|
||||
mock_inference_api.openai_chat_completion.assert_called_once()
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args
|
||||
sent_messages = call_args.kwargs["messages"]
|
||||
params = call_args.args[0]
|
||||
sent_messages = params.messages
|
||||
|
||||
# Check that instructions were prepended as a system message
|
||||
assert len(sent_messages) == 4, sent_messages
|
||||
|
|
@ -767,6 +815,69 @@ async def test_create_openai_response_with_instructions_and_previous_response(
|
|||
assert sent_messages[3].content == "Which is the largest?"
|
||||
|
||||
|
||||
async def test_create_openai_response_with_previous_response_instructions(
|
||||
openai_responses_impl, mock_responses_store, mock_inference_api
|
||||
):
|
||||
"""Test prepending instructions and previous response with instructions."""
|
||||
|
||||
input_item_message = OpenAIResponseMessage(
|
||||
id="123",
|
||||
content="Name some towns in Ireland",
|
||||
role="user",
|
||||
)
|
||||
response_output_message = OpenAIResponseMessage(
|
||||
id="123",
|
||||
content="Galway, Longford, Sligo",
|
||||
status="completed",
|
||||
role="assistant",
|
||||
)
|
||||
response = _OpenAIResponseObjectWithInputAndMessages(
|
||||
created_at=1,
|
||||
id="resp_123",
|
||||
model="fake_model",
|
||||
output=[response_output_message],
|
||||
status="completed",
|
||||
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
|
||||
input=[input_item_message],
|
||||
messages=[
|
||||
OpenAIUserMessageParam(content="Name some towns in Ireland"),
|
||||
OpenAIAssistantMessageParam(content="Galway, Longford, Sligo"),
|
||||
],
|
||||
instructions="You are a helpful assistant.",
|
||||
)
|
||||
mock_responses_store.get_response_object.return_value = response
|
||||
|
||||
model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
instructions = "You are a geography expert. Provide concise answers."
|
||||
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream()
|
||||
|
||||
# Execute
|
||||
await openai_responses_impl.create_openai_response(
|
||||
input="Which is the largest?", model=model, instructions=instructions, previous_response_id="123"
|
||||
)
|
||||
|
||||
# Verify
|
||||
mock_inference_api.openai_chat_completion.assert_called_once()
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args
|
||||
params = call_args.args[0]
|
||||
sent_messages = params.messages
|
||||
|
||||
# Check that instructions were prepended as a system message
|
||||
# and that the previous response instructions were not carried over
|
||||
assert len(sent_messages) == 4, sent_messages
|
||||
assert sent_messages[0].role == "system"
|
||||
assert sent_messages[0].content == instructions
|
||||
|
||||
# Check the rest of the messages were converted correctly
|
||||
assert sent_messages[1].role == "user"
|
||||
assert sent_messages[1].content == "Name some towns in Ireland"
|
||||
assert sent_messages[2].role == "assistant"
|
||||
assert sent_messages[2].content == "Galway, Longford, Sligo"
|
||||
assert sent_messages[3].role == "user"
|
||||
assert sent_messages[3].content == "Which is the largest?"
|
||||
|
||||
|
||||
async def test_list_openai_response_input_items_delegation(openai_responses_impl, mock_responses_store):
|
||||
"""Test that list_openai_response_input_items properly delegates to responses_store with correct parameters."""
|
||||
# Setup
|
||||
|
|
@ -807,8 +918,10 @@ async def test_responses_store_list_input_items_logic():
|
|||
|
||||
# Create mock store and response store
|
||||
mock_sql_store = AsyncMock()
|
||||
backend_name = "sql_responses_test"
|
||||
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path="mock_db_path")})
|
||||
responses_store = ResponsesStore(
|
||||
ResponsesStoreConfig(sql_store_config=SqliteSqlStoreConfig(db_path="mock_db_path")), policy=default_policy()
|
||||
ResponsesStoreReference(backend=backend_name, table_name="responses"), policy=default_policy()
|
||||
)
|
||||
responses_store.sql_store = mock_sql_store
|
||||
|
||||
|
|
@ -953,6 +1066,58 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
|
|||
assert result.status == "completed"
|
||||
|
||||
|
||||
@patch("llama_stack.providers.utils.tools.mcp.list_mcp_tools")
|
||||
async def test_reuse_mcp_tool_list(
|
||||
mock_list_mcp_tools, openai_responses_impl, mock_responses_store, mock_inference_api
|
||||
):
|
||||
"""Test that mcp_list_tools can be reused where appropriate."""
|
||||
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream()
|
||||
mock_list_mcp_tools.return_value = ListToolDefsResponse(
|
||||
data=[ToolDef(name="test_tool", description="a test tool", input_schema={}, output_schema={})]
|
||||
)
|
||||
|
||||
res1 = await openai_responses_impl.create_openai_response(
|
||||
input="What is 2+2?",
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
store=True,
|
||||
tools=[
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
||||
],
|
||||
)
|
||||
args = mock_responses_store.store_response_object.call_args
|
||||
data = args.kwargs["response_object"].model_dump()
|
||||
data["input"] = [input_item.model_dump() for input_item in args.kwargs["input"]]
|
||||
data["messages"] = [msg.model_dump() for msg in args.kwargs["messages"]]
|
||||
stored = _OpenAIResponseObjectWithInputAndMessages(**data)
|
||||
mock_responses_store.get_response_object.return_value = stored
|
||||
|
||||
res2 = await openai_responses_impl.create_openai_response(
|
||||
previous_response_id=res1.id,
|
||||
input="Now what is 3+3?",
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
store=True,
|
||||
tools=[
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
||||
],
|
||||
)
|
||||
assert len(mock_inference_api.openai_chat_completion.call_args_list) == 2
|
||||
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
|
||||
second_params = second_call.args[0]
|
||||
tools_seen = second_params.tools
|
||||
assert len(tools_seen) == 1
|
||||
assert tools_seen[0]["function"]["name"] == "test_tool"
|
||||
assert tools_seen[0]["function"]["description"] == "a test tool"
|
||||
|
||||
assert mock_list_mcp_tools.call_count == 1
|
||||
listings = [obj for obj in res2.output if obj.type == "mcp_list_tools"]
|
||||
assert len(listings) == 1
|
||||
assert listings[0].server_label == "alabel"
|
||||
assert len(listings[0].tools) == 1
|
||||
assert listings[0].tools[0].name == "test_tool"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"text_format, response_format",
|
||||
[
|
||||
|
|
@ -987,8 +1152,9 @@ async def test_create_openai_response_with_text_format(
|
|||
|
||||
# Verify
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
assert first_call.kwargs["messages"][0].content == input_text
|
||||
assert first_call.kwargs["response_format"] == response_format
|
||||
first_params = first_call.args[0]
|
||||
assert first_params.messages[0].content == input_text
|
||||
assert first_params.response_format == response_format
|
||||
|
||||
|
||||
async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):
|
||||
|
|
@ -1004,3 +1170,75 @@ async def test_create_openai_response_with_invalid_text_format(openai_responses_
|
|||
model=model,
|
||||
text=OpenAIResponseText(format={"type": "invalid"}),
|
||||
)
|
||||
|
||||
|
||||
async def test_create_openai_response_with_output_types_as_input(
|
||||
openai_responses_impl, mock_inference_api, mock_responses_store
|
||||
):
|
||||
"""Test that response outputs can be used as inputs in multi-turn conversations.
|
||||
|
||||
Before adding OpenAIResponseOutput types to OpenAIResponseInput,
|
||||
creating a _OpenAIResponseObjectWithInputAndMessages with some output types
|
||||
in the input field would fail with a Pydantic ValidationError.
|
||||
|
||||
This test simulates storing a response where the input contains output message
|
||||
types (MCP calls, function calls), which happens in multi-turn conversations.
|
||||
"""
|
||||
model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
# Mock the inference response
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream()
|
||||
|
||||
# Create a response with store=True to trigger the storage path
|
||||
result = await openai_responses_impl.create_openai_response(
|
||||
input="What's the weather?",
|
||||
model=model,
|
||||
stream=True,
|
||||
temperature=0.1,
|
||||
store=True,
|
||||
)
|
||||
|
||||
# Consume the stream
|
||||
_ = [chunk async for chunk in result]
|
||||
|
||||
# Verify store was called
|
||||
assert mock_responses_store.store_response_object.called
|
||||
|
||||
# Get the stored data
|
||||
store_call_args = mock_responses_store.store_response_object.call_args
|
||||
stored_response = store_call_args.kwargs["response_object"]
|
||||
|
||||
# Now simulate a multi-turn conversation where outputs become inputs
|
||||
input_with_output_types = [
|
||||
OpenAIResponseMessage(role="user", content="What's the weather?", name=None),
|
||||
# These output types need to be valid OpenAIResponseInput
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
call_id="call_123",
|
||||
name="get_weather",
|
||||
arguments='{"city": "Tokyo"}',
|
||||
type="function_call",
|
||||
),
|
||||
OpenAIResponseOutputMessageMCPCall(
|
||||
id="mcp_456",
|
||||
type="mcp_call",
|
||||
server_label="weather_server",
|
||||
name="get_temperature",
|
||||
arguments='{"location": "Tokyo"}',
|
||||
output="25°C",
|
||||
),
|
||||
]
|
||||
|
||||
# This simulates storing a response in a multi-turn conversation
|
||||
# where previous outputs are included in the input.
|
||||
stored_with_outputs = _OpenAIResponseObjectWithInputAndMessages(
|
||||
id=stored_response.id,
|
||||
created_at=stored_response.created_at,
|
||||
model=stored_response.model,
|
||||
status=stored_response.status,
|
||||
output=stored_response.output,
|
||||
input=input_with_output_types, # This will trigger Pydantic validation
|
||||
messages=None,
|
||||
)
|
||||
|
||||
assert stored_with_outputs.input == input_with_output_types
|
||||
assert len(stored_with_outputs.input) == 3
|
||||
|
|
|
|||
|
|
@ -0,0 +1,249 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseObjectStreamResponseCompleted,
|
||||
OpenAIResponseObjectStreamResponseOutputItemDone,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
)
|
||||
from llama_stack.apis.common.errors import (
|
||||
ConversationNotFoundError,
|
||||
InvalidConversationIdError,
|
||||
)
|
||||
from llama_stack.apis.conversations.conversations import (
|
||||
ConversationItemList,
|
||||
)
|
||||
|
||||
# Import existing fixtures from the main responses test file
|
||||
pytest_plugins = ["tests.unit.providers.agents.meta_reference.test_openai_responses"]
|
||||
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def responses_impl_with_conversations(
|
||||
mock_inference_api,
|
||||
mock_tool_groups_api,
|
||||
mock_tool_runtime_api,
|
||||
mock_responses_store,
|
||||
mock_vector_io_api,
|
||||
mock_conversations_api,
|
||||
mock_safety_api,
|
||||
):
|
||||
"""Create OpenAIResponsesImpl instance with conversations API."""
|
||||
return OpenAIResponsesImpl(
|
||||
inference_api=mock_inference_api,
|
||||
tool_groups_api=mock_tool_groups_api,
|
||||
tool_runtime_api=mock_tool_runtime_api,
|
||||
responses_store=mock_responses_store,
|
||||
vector_io_api=mock_vector_io_api,
|
||||
conversations_api=mock_conversations_api,
|
||||
safety_api=mock_safety_api,
|
||||
)
|
||||
|
||||
|
||||
class TestConversationValidation:
|
||||
"""Test conversation ID validation logic."""
|
||||
|
||||
async def test_nonexistent_conversation_raises_error(
|
||||
self, responses_impl_with_conversations, mock_conversations_api
|
||||
):
|
||||
"""Test that ConversationNotFoundError is raised for non-existent conversation."""
|
||||
conv_id = "conv_nonexistent"
|
||||
|
||||
# Mock conversation not found
|
||||
mock_conversations_api.list_items.side_effect = ConversationNotFoundError("conv_nonexistent")
|
||||
|
||||
with pytest.raises(ConversationNotFoundError):
|
||||
await responses_impl_with_conversations.create_openai_response(
|
||||
input="Hello", model="test-model", conversation=conv_id, stream=False
|
||||
)
|
||||
|
||||
|
||||
class TestMessageSyncing:
|
||||
"""Test message syncing to conversations."""
|
||||
|
||||
async def test_sync_response_to_conversation_simple(
|
||||
self, responses_impl_with_conversations, mock_conversations_api
|
||||
):
|
||||
"""Test syncing simple response to conversation."""
|
||||
conv_id = "conv_test123"
|
||||
input_text = "What are the 5 Ds of dodgeball?"
|
||||
|
||||
# Output items (what the model generated)
|
||||
output_items = [
|
||||
OpenAIResponseMessage(
|
||||
id="msg_response",
|
||||
content=[
|
||||
OpenAIResponseOutputMessageContentOutputText(
|
||||
text="The 5 Ds are: Dodge, Duck, Dip, Dive, and Dodge.", type="output_text", annotations=[]
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
status="completed",
|
||||
type="message",
|
||||
)
|
||||
]
|
||||
|
||||
await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_text, output_items)
|
||||
|
||||
# should call add_items with user input and assistant response
|
||||
mock_conversations_api.add_items.assert_called_once()
|
||||
call_args = mock_conversations_api.add_items.call_args
|
||||
|
||||
assert call_args[0][0] == conv_id # conversation_id
|
||||
items = call_args[0][1] # conversation_items
|
||||
|
||||
assert len(items) == 2
|
||||
# User message
|
||||
assert items[0].type == "message"
|
||||
assert items[0].role == "user"
|
||||
assert items[0].content[0].type == "input_text"
|
||||
assert items[0].content[0].text == input_text
|
||||
|
||||
# Assistant message
|
||||
assert items[1].type == "message"
|
||||
assert items[1].role == "assistant"
|
||||
|
||||
async def test_sync_response_to_conversation_api_error(
|
||||
self, responses_impl_with_conversations, mock_conversations_api
|
||||
):
|
||||
mock_conversations_api.add_items.side_effect = Exception("API Error")
|
||||
output_items = []
|
||||
|
||||
# matching the behavior of OpenAI here
|
||||
with pytest.raises(Exception, match="API Error"):
|
||||
await responses_impl_with_conversations._sync_response_to_conversation(
|
||||
"conv_test123", "Hello", output_items
|
||||
)
|
||||
|
||||
async def test_sync_with_list_input(self, responses_impl_with_conversations, mock_conversations_api):
|
||||
"""Test syncing with list of input messages."""
|
||||
conv_id = "conv_test123"
|
||||
input_messages = [
|
||||
OpenAIResponseMessage(role="user", content=[{"type": "input_text", "text": "First message"}]),
|
||||
]
|
||||
output_items = [
|
||||
OpenAIResponseMessage(
|
||||
id="msg_response",
|
||||
content=[OpenAIResponseOutputMessageContentOutputText(text="Response", type="output_text")],
|
||||
role="assistant",
|
||||
status="completed",
|
||||
type="message",
|
||||
)
|
||||
]
|
||||
|
||||
await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_messages, output_items)
|
||||
|
||||
mock_conversations_api.add_items.assert_called_once()
|
||||
call_args = mock_conversations_api.add_items.call_args
|
||||
|
||||
items = call_args[0][1]
|
||||
# Should have input message + output message
|
||||
assert len(items) == 2
|
||||
|
||||
|
||||
class TestIntegrationWorkflow:
|
||||
"""Integration tests for the full conversation workflow."""
|
||||
|
||||
async def test_create_response_with_valid_conversation(
|
||||
self, responses_impl_with_conversations, mock_conversations_api
|
||||
):
|
||||
"""Test creating a response with a valid conversation parameter."""
|
||||
mock_conversations_api.list_items.return_value = ConversationItemList(
|
||||
data=[], first_id=None, has_more=False, last_id=None, object="list"
|
||||
)
|
||||
|
||||
async def mock_streaming_response(*args, **kwargs):
|
||||
message_item = OpenAIResponseMessage(
|
||||
id="msg_response",
|
||||
content=[
|
||||
OpenAIResponseOutputMessageContentOutputText(
|
||||
text="Test response", type="output_text", annotations=[]
|
||||
)
|
||||
],
|
||||
role="assistant",
|
||||
status="completed",
|
||||
type="message",
|
||||
)
|
||||
|
||||
# Emit output_item.done event first (needed for conversation sync)
|
||||
yield OpenAIResponseObjectStreamResponseOutputItemDone(
|
||||
response_id="resp_test123",
|
||||
item=message_item,
|
||||
output_index=0,
|
||||
sequence_number=1,
|
||||
type="response.output_item.done",
|
||||
)
|
||||
|
||||
# Then emit response.completed
|
||||
mock_response = OpenAIResponseObject(
|
||||
id="resp_test123",
|
||||
created_at=1234567890,
|
||||
model="test-model",
|
||||
object="response",
|
||||
output=[message_item],
|
||||
status="completed",
|
||||
)
|
||||
|
||||
yield OpenAIResponseObjectStreamResponseCompleted(response=mock_response, type="response.completed")
|
||||
|
||||
responses_impl_with_conversations._create_streaming_response = mock_streaming_response
|
||||
|
||||
input_text = "Hello, how are you?"
|
||||
conversation_id = "conv_test123"
|
||||
|
||||
response = await responses_impl_with_conversations.create_openai_response(
|
||||
input=input_text, model="test-model", conversation=conversation_id, stream=False
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
assert response.id == "resp_test123"
|
||||
|
||||
# Note: conversation sync happens inside _create_streaming_response,
|
||||
# which we're mocking here, so we can't test it in this unit test.
|
||||
# The sync logic is tested separately in TestMessageSyncing.
|
||||
|
||||
async def test_create_response_with_invalid_conversation_id(self, responses_impl_with_conversations):
|
||||
"""Test creating a response with an invalid conversation ID."""
|
||||
with pytest.raises(InvalidConversationIdError) as exc_info:
|
||||
await responses_impl_with_conversations.create_openai_response(
|
||||
input="Hello", model="test-model", conversation="invalid_id", stream=False
|
||||
)
|
||||
|
||||
assert "Expected an ID that begins with 'conv_'" in str(exc_info.value)
|
||||
|
||||
async def test_create_response_with_nonexistent_conversation(
|
||||
self, responses_impl_with_conversations, mock_conversations_api
|
||||
):
|
||||
"""Test creating a response with a non-existent conversation."""
|
||||
mock_conversations_api.list_items.side_effect = ConversationNotFoundError("conv_nonexistent")
|
||||
|
||||
with pytest.raises(ConversationNotFoundError) as exc_info:
|
||||
await responses_impl_with_conversations.create_openai_response(
|
||||
input="Hello", model="test-model", conversation="conv_nonexistent", stream=False
|
||||
)
|
||||
|
||||
assert "not found" in str(exc_info.value)
|
||||
|
||||
async def test_conversation_and_previous_response_id(
|
||||
self, responses_impl_with_conversations, mock_conversations_api, mock_responses_store
|
||||
):
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await responses_impl_with_conversations.create_openai_response(
|
||||
input="test", model="test", conversation="conv_123", previous_response_id="resp_123"
|
||||
)
|
||||
|
||||
assert "Mutually exclusive parameters" in str(exc_info.value)
|
||||
assert "previous_response_id" in str(exc_info.value)
|
||||
assert "conversation" in str(exc_info.value)
|
||||
|
|
@ -0,0 +1,183 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
MCPListToolsTool,
|
||||
OpenAIResponseInputToolFileSearch,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseObject,
|
||||
OpenAIResponseOutputMessageMCPListTools,
|
||||
OpenAIResponseToolMCP,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext
|
||||
|
||||
|
||||
class TestToolContext:
|
||||
def test_no_tools(self):
|
||||
tools = []
|
||||
context = ToolContext(tools)
|
||||
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="mymodel", output=[], status="")
|
||||
context.recover_tools_from_previous_response(previous_response)
|
||||
|
||||
assert len(context.tools_to_process) == 0
|
||||
assert len(context.previous_tools) == 0
|
||||
assert len(context.previous_tool_listings) == 0
|
||||
|
||||
def test_no_previous_tools(self):
|
||||
tools = [
|
||||
OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
|
||||
OpenAIResponseInputToolMCP(server_label="label", server_url="url"),
|
||||
]
|
||||
context = ToolContext(tools)
|
||||
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="mymodel", output=[], status="")
|
||||
context.recover_tools_from_previous_response(previous_response)
|
||||
|
||||
assert len(context.tools_to_process) == 2
|
||||
assert len(context.previous_tools) == 0
|
||||
assert len(context.previous_tool_listings) == 0
|
||||
|
||||
def test_reusable_server(self):
|
||||
tools = [
|
||||
OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
||||
]
|
||||
context = ToolContext(tools)
|
||||
output = [
|
||||
OpenAIResponseOutputMessageMCPListTools(
|
||||
id="test", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
|
||||
)
|
||||
]
|
||||
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
|
||||
previous_response.tools = [
|
||||
OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
|
||||
OpenAIResponseToolMCP(server_label="alabel"),
|
||||
]
|
||||
context.recover_tools_from_previous_response(previous_response)
|
||||
|
||||
assert len(context.tools_to_process) == 1
|
||||
assert context.tools_to_process[0].type == "file_search"
|
||||
assert len(context.previous_tools) == 1
|
||||
assert context.previous_tools["test_tool"].server_label == "alabel"
|
||||
assert context.previous_tools["test_tool"].server_url == "aurl"
|
||||
assert len(context.previous_tool_listings) == 1
|
||||
assert len(context.previous_tool_listings[0].tools) == 1
|
||||
assert context.previous_tool_listings[0].server_label == "alabel"
|
||||
|
||||
def test_multiple_reusable_servers(self):
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
|
||||
OpenAIResponseInputToolWebSearch(),
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
||||
]
|
||||
context = ToolContext(tools)
|
||||
output = [
|
||||
OpenAIResponseOutputMessageMCPListTools(
|
||||
id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
|
||||
),
|
||||
OpenAIResponseOutputMessageMCPListTools(
|
||||
id="test2",
|
||||
server_label="anotherlabel",
|
||||
tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
|
||||
),
|
||||
]
|
||||
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
|
||||
previous_response.tools = [
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
|
||||
]
|
||||
context.recover_tools_from_previous_response(previous_response)
|
||||
|
||||
assert len(context.tools_to_process) == 2
|
||||
assert context.tools_to_process[0].type == "function"
|
||||
assert context.tools_to_process[1].type == "web_search"
|
||||
assert len(context.previous_tools) == 2
|
||||
assert context.previous_tools["test_tool"].server_label == "alabel"
|
||||
assert context.previous_tools["test_tool"].server_url == "aurl"
|
||||
assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
|
||||
assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
|
||||
assert len(context.previous_tool_listings) == 2
|
||||
assert len(context.previous_tool_listings[0].tools) == 1
|
||||
assert context.previous_tool_listings[0].server_label == "alabel"
|
||||
assert len(context.previous_tool_listings[1].tools) == 1
|
||||
assert context.previous_tool_listings[1].server_label == "anotherlabel"
|
||||
|
||||
def test_multiple_servers_only_one_reusable(self):
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
|
||||
]
|
||||
context = ToolContext(tools)
|
||||
output = [
|
||||
OpenAIResponseOutputMessageMCPListTools(
|
||||
id="test2",
|
||||
server_label="anotherlabel",
|
||||
tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
|
||||
)
|
||||
]
|
||||
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
|
||||
previous_response.tools = [
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
]
|
||||
context.recover_tools_from_previous_response(previous_response)
|
||||
|
||||
assert len(context.tools_to_process) == 3
|
||||
assert context.tools_to_process[0].type == "function"
|
||||
assert context.tools_to_process[1].type == "web_search"
|
||||
assert context.tools_to_process[2].type == "mcp"
|
||||
assert len(context.previous_tools) == 1
|
||||
assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
|
||||
assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
|
||||
assert len(context.previous_tool_listings) == 1
|
||||
assert len(context.previous_tool_listings[0].tools) == 1
|
||||
assert context.previous_tool_listings[0].server_label == "anotherlabel"
|
||||
|
||||
def test_mismatched_allowed_tools(self):
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl", allowed_tools=["test_tool_2"]),
|
||||
]
|
||||
context = ToolContext(tools)
|
||||
output = [
|
||||
OpenAIResponseOutputMessageMCPListTools(
|
||||
id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool_1", input_schema={})]
|
||||
),
|
||||
OpenAIResponseOutputMessageMCPListTools(
|
||||
id="test2",
|
||||
server_label="anotherlabel",
|
||||
tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
|
||||
),
|
||||
]
|
||||
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
|
||||
previous_response.tools = [
|
||||
OpenAIResponseInputToolFunction(name="fake", parameters=None),
|
||||
OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
|
||||
]
|
||||
context.recover_tools_from_previous_response(previous_response)
|
||||
|
||||
assert len(context.tools_to_process) == 3
|
||||
assert context.tools_to_process[0].type == "function"
|
||||
assert context.tools_to_process[1].type == "web_search"
|
||||
assert context.tools_to_process[2].type == "mcp"
|
||||
assert len(context.previous_tools) == 1
|
||||
assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
|
||||
assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
|
||||
assert len(context.previous_tool_listings) == 1
|
||||
assert len(context.previous_tool_listings[0].tools) == 1
|
||||
assert context.previous_tool_listings[0].server_label == "anotherlabel"
|
||||
|
|
@ -0,0 +1,155 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
|
||||
from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
|
||||
extract_guardrail_ids,
|
||||
run_guardrails,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_apis():
|
||||
"""Create mock APIs for testing."""
|
||||
return {
|
||||
"inference_api": AsyncMock(),
|
||||
"tool_groups_api": AsyncMock(),
|
||||
"tool_runtime_api": AsyncMock(),
|
||||
"responses_store": AsyncMock(),
|
||||
"vector_io_api": AsyncMock(),
|
||||
"conversations_api": AsyncMock(),
|
||||
"safety_api": AsyncMock(),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def responses_impl(mock_apis):
|
||||
"""Create OpenAIResponsesImpl instance with mocked dependencies."""
|
||||
return OpenAIResponsesImpl(**mock_apis)
|
||||
|
||||
|
||||
def test_extract_guardrail_ids_from_strings(responses_impl):
|
||||
"""Test extraction from simple string guardrail IDs."""
|
||||
guardrails = ["llama-guard", "content-filter", "nsfw-detector"]
|
||||
result = extract_guardrail_ids(guardrails)
|
||||
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
|
||||
|
||||
|
||||
def test_extract_guardrail_ids_from_objects(responses_impl):
|
||||
"""Test extraction from ResponseGuardrailSpec objects."""
|
||||
guardrails = [
|
||||
ResponseGuardrailSpec(type="llama-guard"),
|
||||
ResponseGuardrailSpec(type="content-filter"),
|
||||
]
|
||||
result = extract_guardrail_ids(guardrails)
|
||||
assert result == ["llama-guard", "content-filter"]
|
||||
|
||||
|
||||
def test_extract_guardrail_ids_mixed_formats(responses_impl):
|
||||
"""Test extraction from mixed string and object formats."""
|
||||
guardrails = [
|
||||
"llama-guard",
|
||||
ResponseGuardrailSpec(type="content-filter"),
|
||||
"nsfw-detector",
|
||||
]
|
||||
result = extract_guardrail_ids(guardrails)
|
||||
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
|
||||
|
||||
|
||||
def test_extract_guardrail_ids_none_input(responses_impl):
|
||||
"""Test extraction with None input."""
|
||||
result = extract_guardrail_ids(None)
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_extract_guardrail_ids_empty_list(responses_impl):
|
||||
"""Test extraction with empty list."""
|
||||
result = extract_guardrail_ids([])
|
||||
assert result == []
|
||||
|
||||
|
||||
def test_extract_guardrail_ids_unknown_format(responses_impl):
|
||||
"""Test extraction with unknown guardrail format raises ValueError."""
|
||||
# Create an object that's neither string nor ResponseGuardrailSpec
|
||||
unknown_object = {"invalid": "format"} # Plain dict, not ResponseGuardrailSpec
|
||||
guardrails = ["valid-guardrail", unknown_object, "another-guardrail"]
|
||||
with pytest.raises(ValueError, match="Unknown guardrail format.*expected str or ResponseGuardrailSpec"):
|
||||
extract_guardrail_ids(guardrails)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_safety_api():
|
||||
"""Create mock safety API for guardrails testing."""
|
||||
safety_api = AsyncMock()
|
||||
# Mock the routing table and shields list for guardrails lookup
|
||||
safety_api.routing_table = AsyncMock()
|
||||
shield = AsyncMock()
|
||||
shield.identifier = "llama-guard"
|
||||
shield.provider_resource_id = "llama-guard-model"
|
||||
safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
|
||||
return safety_api
|
||||
|
||||
|
||||
async def test_run_guardrails_no_violation(mock_safety_api):
|
||||
"""Test guardrails validation with no violations."""
|
||||
text = "Hello world"
|
||||
guardrail_ids = ["llama-guard"]
|
||||
|
||||
# Mock moderation to return non-flagged content
|
||||
unflagged_result = ModerationObjectResults(flagged=False, categories={"violence": False})
|
||||
mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[unflagged_result])
|
||||
mock_safety_api.run_moderation.return_value = mock_moderation_object
|
||||
|
||||
result = await run_guardrails(mock_safety_api, text, guardrail_ids)
|
||||
|
||||
assert result is None
|
||||
# Verify run_moderation was called with the correct model
|
||||
mock_safety_api.run_moderation.assert_called_once()
|
||||
call_args = mock_safety_api.run_moderation.call_args
|
||||
assert call_args[1]["model"] == "llama-guard-model"
|
||||
|
||||
|
||||
async def test_run_guardrails_with_violation(mock_safety_api):
|
||||
"""Test guardrails validation with safety violation."""
|
||||
text = "Harmful content"
|
||||
guardrail_ids = ["llama-guard"]
|
||||
|
||||
# Mock moderation to return flagged content
|
||||
flagged_result = ModerationObjectResults(
|
||||
flagged=True,
|
||||
categories={"violence": True},
|
||||
user_message="Content flagged by moderation",
|
||||
metadata={"violation_type": ["S1"]},
|
||||
)
|
||||
mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])
|
||||
mock_safety_api.run_moderation.return_value = mock_moderation_object
|
||||
|
||||
result = await run_guardrails(mock_safety_api, text, guardrail_ids)
|
||||
|
||||
assert result == "Content flagged by moderation (flagged for: violence) (violation type: S1)"
|
||||
|
||||
|
||||
async def test_run_guardrails_empty_inputs(mock_safety_api):
|
||||
"""Test guardrails validation with empty inputs."""
|
||||
# Test empty guardrail_ids
|
||||
result = await run_guardrails(mock_safety_api, "test", [])
|
||||
assert result is None
|
||||
|
||||
# Test empty text
|
||||
result = await run_guardrails(mock_safety_api, "", ["llama-guard"])
|
||||
assert result is None
|
||||
|
||||
# Test both empty
|
||||
result = await run_guardrails(mock_safety_api, "", [])
|
||||
assert result is None
|
||||
|
|
@ -12,10 +12,10 @@ from unittest.mock import AsyncMock
|
|||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
|
||||
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -23,8 +23,10 @@ async def provider():
|
|||
"""Create a test provider instance with temporary database."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test_batches.db"
|
||||
backend_name = "kv_batches_test"
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
|
||||
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
|
||||
register_kvstore_backends({backend_name: kvstore_config})
|
||||
config = ReferenceBatchesImplConfig(kvstore=KVStoreReference(backend=backend_name, namespace="batches"))
|
||||
|
||||
# Create kvstore and mock APIs
|
||||
kvstore = await kvstore_impl(config.kvstore)
|
||||
|
|
|
|||
|
|
@ -213,7 +213,6 @@ class TestReferenceBatchesImpl:
|
|||
@pytest.mark.parametrize(
|
||||
"endpoint",
|
||||
[
|
||||
"/v1/embeddings",
|
||||
"/v1/invalid/endpoint",
|
||||
"",
|
||||
],
|
||||
|
|
@ -765,3 +764,12 @@ class TestReferenceBatchesImpl:
|
|||
await asyncio.sleep(0.042) # let tasks start
|
||||
|
||||
assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"
|
||||
|
||||
async def test_create_batch_embeddings_endpoint(self, provider):
|
||||
"""Test that batch creation succeeds with embeddings endpoint."""
|
||||
batch = await provider.create_batch(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/embeddings",
|
||||
completion_window="24h",
|
||||
)
|
||||
assert batch.endpoint == "/v1/embeddings"
|
||||
|
|
|
|||
|
|
@ -8,8 +8,9 @@ import boto3
|
|||
import pytest
|
||||
from moto import mock_aws
|
||||
|
||||
from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
|
||||
from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
class MockUploadFile:
|
||||
|
|
@ -38,11 +39,13 @@ def sample_text_file2():
|
|||
def s3_config(tmp_path):
|
||||
db_path = tmp_path / "s3_files_metadata.db"
|
||||
|
||||
backend_name = f"sql_s3_{tmp_path.name}"
|
||||
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path.as_posix())})
|
||||
return S3FilesImplConfig(
|
||||
bucket_name=f"test-bucket-{tmp_path.name}",
|
||||
region="not-a-region",
|
||||
auto_create_bucket=True,
|
||||
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
|
||||
metadata_store=SqlStoreReference(backend=backend_name, table_name="s3_files_metadata"),
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -15,16 +15,16 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp
|
|||
|
||||
|
||||
# Test fixtures and helper classes
|
||||
class TestConfig(BaseModel):
|
||||
class FakeConfig(BaseModel):
|
||||
api_key: str | None = Field(default=None)
|
||||
|
||||
|
||||
class TestProviderDataValidator(BaseModel):
|
||||
class FakeProviderDataValidator(BaseModel):
|
||||
test_api_key: str | None = Field(default=None)
|
||||
|
||||
|
||||
class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: TestConfig):
|
||||
class FakeLiteLLMAdapter(LiteLLMOpenAIMixin):
|
||||
def __init__(self, config: FakeConfig):
|
||||
super().__init__(
|
||||
litellm_provider_name="test",
|
||||
api_key_from_config=config.api_key,
|
||||
|
|
@ -36,11 +36,11 @@ class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
|
|||
@pytest.fixture
|
||||
def adapter_with_config_key():
|
||||
"""Fixture to create adapter with API key in config"""
|
||||
config = TestConfig(api_key="config-api-key")
|
||||
adapter = TestLiteLLMAdapter(config)
|
||||
config = FakeConfig(api_key="config-api-key")
|
||||
adapter = FakeLiteLLMAdapter(config)
|
||||
adapter.__provider_spec__ = MagicMock()
|
||||
adapter.__provider_spec__.provider_data_validator = (
|
||||
"tests.unit.providers.inference.test_litellm_openai_mixin.TestProviderDataValidator"
|
||||
"tests.unit.providers.inference.test_litellm_openai_mixin.FakeProviderDataValidator"
|
||||
)
|
||||
return adapter
|
||||
|
||||
|
|
@ -48,11 +48,11 @@ def adapter_with_config_key():
|
|||
@pytest.fixture
|
||||
def adapter_without_config_key():
|
||||
"""Fixture to create adapter without API key in config"""
|
||||
config = TestConfig(api_key=None)
|
||||
adapter = TestLiteLLMAdapter(config)
|
||||
config = FakeConfig(api_key=None)
|
||||
adapter = FakeLiteLLMAdapter(config)
|
||||
adapter.__provider_spec__ = MagicMock()
|
||||
adapter.__provider_spec__.provider_data_validator = (
|
||||
"tests.unit.providers.inference.test_litellm_openai_mixin.TestProviderDataValidator"
|
||||
"tests.unit.providers.inference.test_litellm_openai_mixin.FakeProviderDataValidator"
|
||||
)
|
||||
return adapter
|
||||
|
||||
|
|
|
|||
|
|
@ -13,10 +13,16 @@ import pytest
|
|||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletion,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAIChoice,
|
||||
OpenAICompletion,
|
||||
OpenAICompletionChoice,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
ToolChoice,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.core.routers.inference import InferenceRouter
|
||||
from llama_stack.core.routing_tables.models import ModelsRoutingTable
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
|
||||
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
|
||||
|
|
@ -56,13 +62,14 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
|
|||
mock_client_property.return_value = mock_client
|
||||
|
||||
# No tools but auto tool choice
|
||||
await vllm_inference_adapter.openai_chat_completion(
|
||||
"mock-model",
|
||||
[],
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="mock-model",
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
stream=False,
|
||||
tools=None,
|
||||
tool_choice=ToolChoice.auto.value,
|
||||
)
|
||||
await vllm_inference_adapter.openai_chat_completion(params)
|
||||
mock_client.chat.completions.create.assert_called()
|
||||
call_args = mock_client.chat.completions.create.call_args
|
||||
# Ensure tool_choice gets converted to none for older vLLM versions
|
||||
|
|
@ -171,9 +178,12 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
|
|||
)
|
||||
|
||||
async def do_inference():
|
||||
await vllm_inference_adapter.openai_chat_completion(
|
||||
"mock-model", messages=["one fish", "two fish"], stream=False
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="mock-model",
|
||||
messages=[{"role": "user", "content": "one fish two fish"}],
|
||||
stream=False,
|
||||
)
|
||||
await vllm_inference_adapter.openai_chat_completion(params)
|
||||
|
||||
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
|
||||
mock_client = MagicMock()
|
||||
|
|
@ -186,3 +196,148 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
|
|||
|
||||
assert mock_create_client.call_count == 4 # no cheating
|
||||
assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
|
||||
|
||||
|
||||
async def test_vllm_completion_extra_body():
|
||||
"""
|
||||
Test that vLLM-specific guided_choice and prompt_logprobs parameters are correctly forwarded
|
||||
via extra_body to the underlying OpenAI client through the InferenceRouter.
|
||||
"""
|
||||
# Set up the vLLM adapter
|
||||
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
|
||||
vllm_adapter = VLLMInferenceAdapter(config=config)
|
||||
vllm_adapter.__provider_id__ = "vllm"
|
||||
await vllm_adapter.initialize()
|
||||
|
||||
# Create a mock model store
|
||||
mock_model_store = AsyncMock()
|
||||
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm")
|
||||
mock_model_store.get_model.return_value = mock_model
|
||||
mock_model_store.has_model.return_value = True
|
||||
|
||||
# Create a mock dist_registry
|
||||
mock_dist_registry = MagicMock()
|
||||
mock_dist_registry.get = AsyncMock(return_value=mock_model)
|
||||
mock_dist_registry.set = AsyncMock()
|
||||
|
||||
# Set up the routing table
|
||||
routing_table = ModelsRoutingTable(
|
||||
impls_by_provider_id={"vllm": vllm_adapter},
|
||||
dist_registry=mock_dist_registry,
|
||||
policy=[],
|
||||
)
|
||||
# Inject the model store into the adapter
|
||||
vllm_adapter.model_store = routing_table
|
||||
|
||||
# Create the InferenceRouter
|
||||
router = InferenceRouter(routing_table=routing_table)
|
||||
|
||||
# Patch the OpenAI client
|
||||
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
|
||||
mock_client = MagicMock()
|
||||
mock_client.completions.create = AsyncMock(
|
||||
return_value=OpenAICompletion(
|
||||
id="cmpl-abc123",
|
||||
created=1,
|
||||
model="mock-model",
|
||||
choices=[
|
||||
OpenAICompletionChoice(
|
||||
text="joy",
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
mock_client_property.return_value = mock_client
|
||||
|
||||
# Test with guided_choice and prompt_logprobs as extra fields
|
||||
params = OpenAICompletionRequestWithExtraBody(
|
||||
model="mock-model",
|
||||
prompt="I am feeling happy",
|
||||
stream=False,
|
||||
guided_choice=["joy", "sadness"],
|
||||
prompt_logprobs=5,
|
||||
)
|
||||
await router.openai_completion(params)
|
||||
|
||||
# Verify that the client was called with extra_body containing both parameters
|
||||
mock_client.completions.create.assert_called_once()
|
||||
call_kwargs = mock_client.completions.create.call_args.kwargs
|
||||
assert "extra_body" in call_kwargs
|
||||
assert "guided_choice" in call_kwargs["extra_body"]
|
||||
assert call_kwargs["extra_body"]["guided_choice"] == ["joy", "sadness"]
|
||||
assert "prompt_logprobs" in call_kwargs["extra_body"]
|
||||
assert call_kwargs["extra_body"]["prompt_logprobs"] == 5
|
||||
|
||||
|
||||
async def test_vllm_chat_completion_extra_body():
|
||||
"""
|
||||
Test that vLLM-specific parameters (e.g., chat_template_kwargs) are correctly forwarded
|
||||
via extra_body to the underlying OpenAI client through the InferenceRouter for chat completion.
|
||||
"""
|
||||
# Set up the vLLM adapter
|
||||
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
|
||||
vllm_adapter = VLLMInferenceAdapter(config=config)
|
||||
vllm_adapter.__provider_id__ = "vllm"
|
||||
await vllm_adapter.initialize()
|
||||
|
||||
# Create a mock model store
|
||||
mock_model_store = AsyncMock()
|
||||
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm")
|
||||
mock_model_store.get_model.return_value = mock_model
|
||||
mock_model_store.has_model.return_value = True
|
||||
|
||||
# Create a mock dist_registry
|
||||
mock_dist_registry = MagicMock()
|
||||
mock_dist_registry.get = AsyncMock(return_value=mock_model)
|
||||
mock_dist_registry.set = AsyncMock()
|
||||
|
||||
# Set up the routing table
|
||||
routing_table = ModelsRoutingTable(
|
||||
impls_by_provider_id={"vllm": vllm_adapter},
|
||||
dist_registry=mock_dist_registry,
|
||||
policy=[],
|
||||
)
|
||||
# Inject the model store into the adapter
|
||||
vllm_adapter.model_store = routing_table
|
||||
|
||||
# Create the InferenceRouter
|
||||
router = InferenceRouter(routing_table=routing_table)
|
||||
|
||||
# Patch the OpenAI client
|
||||
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(
|
||||
return_value=OpenAIChatCompletion(
|
||||
id="chatcmpl-abc123",
|
||||
created=1,
|
||||
model="mock-model",
|
||||
choices=[
|
||||
OpenAIChoice(
|
||||
message=OpenAIAssistantMessageParam(
|
||||
content="test response",
|
||||
),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
mock_client_property.return_value = mock_client
|
||||
|
||||
# Test with chat_template_kwargs as extra field
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="mock-model",
|
||||
messages=[{"role": "user", "content": "test"}],
|
||||
stream=False,
|
||||
chat_template_kwargs={"thinking": True},
|
||||
)
|
||||
await router.openai_chat_completion(params)
|
||||
|
||||
# Verify that the client was called with extra_body containing chat_template_kwargs
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
call_kwargs = mock_client.chat.completions.create.call_args.kwargs
|
||||
assert "extra_body" in call_kwargs
|
||||
assert "chat_template_kwargs" in call_kwargs["extra_body"]
|
||||
assert call_kwargs["extra_body"]["chat_template_kwargs"] == {"thinking": True}
|
||||
|
|
|
|||
|
|
@ -4,10 +4,45 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
|
||||
convert_tooldef_to_chat_tool,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.types import ChatCompletionContext
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_safety_api():
|
||||
safety_api = AsyncMock()
|
||||
# Mock the routing table and shields list for guardrails lookup
|
||||
safety_api.routing_table = AsyncMock()
|
||||
shield = AsyncMock()
|
||||
shield.identifier = "llama-guard"
|
||||
shield.provider_resource_id = "llama-guard-model"
|
||||
safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
|
||||
# Mock run_moderation to return non-flagged result by default
|
||||
safety_api.run_moderation.return_value = AsyncMock(flagged=False)
|
||||
return safety_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_inference_api():
|
||||
inference_api = AsyncMock()
|
||||
return inference_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context():
|
||||
context = AsyncMock(spec=ChatCompletionContext)
|
||||
# Add required attributes that StreamingResponseOrchestrator expects
|
||||
context.tool_context = AsyncMock()
|
||||
context.tool_context.previous_tools = {}
|
||||
context.messages = []
|
||||
return context
|
||||
|
||||
|
||||
def test_convert_tooldef_to_chat_tool_preserves_items_field():
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
|
|||
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
|
||||
|
||||
|
||||
class TestNVIDIASafetyAdapter(NVIDIASafetyAdapter):
|
||||
class FakeNVIDIASafetyAdapter(NVIDIASafetyAdapter):
|
||||
"""Test implementation that provides the required shield_store."""
|
||||
|
||||
def __init__(self, config: NVIDIASafetyConfig, shield_store):
|
||||
|
|
@ -41,7 +41,7 @@ def nvidia_adapter():
|
|||
shield_store = AsyncMock()
|
||||
shield_store.get_shield = AsyncMock()
|
||||
|
||||
adapter = TestNVIDIASafetyAdapter(config=config, shield_store=shield_store)
|
||||
adapter = FakeNVIDIASafetyAdapter(config=config, shield_store=shield_store)
|
||||
|
||||
return adapter
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
|
|||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
|
||||
from llama_stack.apis.inference import Model, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.core.request_headers import request_provider_data_context
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
|
|
@ -23,10 +23,10 @@ class OpenAIMixinImpl(OpenAIMixin):
|
|||
__provider_id__: str = "test-provider"
|
||||
|
||||
def get_api_key(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
return "test-api-key"
|
||||
|
||||
def get_base_url(self) -> str:
|
||||
raise NotImplementedError("This method should be mocked in tests")
|
||||
return "http://test-base-url"
|
||||
|
||||
|
||||
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
|
||||
|
|
@ -38,6 +38,28 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
|
|||
}
|
||||
|
||||
|
||||
class OpenAIMixinWithCustomModelConstruction(OpenAIMixinImpl):
|
||||
"""Test implementation that uses construct_model_from_identifier to add rerank models"""
|
||||
|
||||
embedding_model_metadata: dict[str, dict[str, int]] = {
|
||||
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
}
|
||||
|
||||
# Adds rerank models via construct_model_from_identifier
|
||||
rerank_model_ids: set[str] = {"rerank-model-1", "rerank-model-2"}
|
||||
|
||||
def construct_model_from_identifier(self, identifier: str) -> Model:
|
||||
if identifier in self.rerank_model_ids:
|
||||
return Model(
|
||||
provider_id=self.__provider_id__, # type: ignore[attr-defined]
|
||||
provider_resource_id=identifier,
|
||||
identifier=identifier,
|
||||
model_type=ModelType.rerank,
|
||||
)
|
||||
return super().construct_model_from_identifier(identifier)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixin():
|
||||
"""Create a test instance of OpenAIMixin with mocked model_store"""
|
||||
|
|
@ -62,6 +84,13 @@ def mixin_with_embeddings():
|
|||
return OpenAIMixinWithEmbeddingsImpl(config=config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mixin_with_custom_model_construction():
|
||||
"""Create a test instance using custom construct_model_from_identifier"""
|
||||
config = RemoteInferenceProviderConfig()
|
||||
return OpenAIMixinWithCustomModelConstruction(config=config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_models():
|
||||
"""Create multiple mock OpenAI model objects"""
|
||||
|
|
@ -113,6 +142,19 @@ def mock_client_context():
|
|||
return _mock_client_context
|
||||
|
||||
|
||||
def _assert_models_match_expected(actual_models, expected_models):
|
||||
"""Verify the models match expected attributes.
|
||||
|
||||
Args:
|
||||
actual_models: List of models to verify
|
||||
expected_models: Mapping of model identifier to expected attribute values
|
||||
"""
|
||||
for identifier, expected_attrs in expected_models.items():
|
||||
model = next(m for m in actual_models if m.identifier == identifier)
|
||||
for attr_name, expected_value in expected_attrs.items():
|
||||
assert getattr(model, attr_name) == expected_value
|
||||
|
||||
|
||||
class TestOpenAIMixinListModels:
|
||||
"""Test cases for the list_models method"""
|
||||
|
||||
|
|
@ -205,7 +247,7 @@ class TestOpenAIMixinCheckModelAvailability:
|
|||
assert await mixin.check_model_availability("pre-registered-model")
|
||||
# Should not call the provider's list_models since model was found in store
|
||||
mock_client_with_models.models.list.assert_not_called()
|
||||
mock_model_store.has_model.assert_called_once_with("pre-registered-model")
|
||||
mock_model_store.has_model.assert_called_once_with("test-provider/pre-registered-model")
|
||||
|
||||
async def test_check_model_availability_fallback_to_provider_when_not_in_store(
|
||||
self, mixin, mock_client_with_models, mock_client_context
|
||||
|
|
@ -222,7 +264,7 @@ class TestOpenAIMixinCheckModelAvailability:
|
|||
assert await mixin.check_model_availability("some-mock-model-id")
|
||||
# Should call the provider's list_models since model was not found in store
|
||||
mock_client_with_models.models.list.assert_called_once()
|
||||
mock_model_store.has_model.assert_called_once_with("some-mock-model-id")
|
||||
mock_model_store.has_model.assert_called_once_with("test-provider/some-mock-model-id")
|
||||
|
||||
|
||||
class TestOpenAIMixinCacheBehavior:
|
||||
|
|
@ -271,7 +313,8 @@ class TestOpenAIMixinImagePreprocessing:
|
|||
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
|
||||
mock_localize.return_value = (b"fake_image_data", "jpeg")
|
||||
|
||||
await mixin.openai_chat_completion(model="test-model", messages=[message])
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(model="test-model", messages=[message])
|
||||
await mixin.openai_chat_completion(params)
|
||||
|
||||
mock_localize.assert_called_once_with("http://example.com/image.jpg")
|
||||
|
||||
|
|
@ -303,7 +346,8 @@ class TestOpenAIMixinImagePreprocessing:
|
|||
|
||||
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
|
||||
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
|
||||
await mixin.openai_chat_completion(model="test-model", messages=[message])
|
||||
params = OpenAIChatCompletionRequestWithExtraBody(model="test-model", messages=[message])
|
||||
await mixin.openai_chat_completion(params)
|
||||
|
||||
mock_localize.assert_not_called()
|
||||
|
||||
|
|
@ -340,21 +384,71 @@ class TestOpenAIMixinEmbeddingModelMetadata:
|
|||
assert result is not None
|
||||
assert len(result) == 2
|
||||
|
||||
# Find the models in the result
|
||||
embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small")
|
||||
llm_model = next(m for m in result if m.identifier == "gpt-4")
|
||||
expected_models = {
|
||||
"text-embedding-3-small": {
|
||||
"model_type": ModelType.embedding,
|
||||
"metadata": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
"provider_id": "test-provider",
|
||||
"provider_resource_id": "text-embedding-3-small",
|
||||
},
|
||||
"gpt-4": {
|
||||
"model_type": ModelType.llm,
|
||||
"metadata": {},
|
||||
"provider_id": "test-provider",
|
||||
"provider_resource_id": "gpt-4",
|
||||
},
|
||||
}
|
||||
|
||||
# Check embedding model
|
||||
assert embedding_model.model_type == ModelType.embedding
|
||||
assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192}
|
||||
assert embedding_model.provider_id == "test-provider"
|
||||
assert embedding_model.provider_resource_id == "text-embedding-3-small"
|
||||
_assert_models_match_expected(result, expected_models)
|
||||
|
||||
# Check LLM model
|
||||
assert llm_model.model_type == ModelType.llm
|
||||
assert llm_model.metadata == {} # No metadata for LLMs
|
||||
assert llm_model.provider_id == "test-provider"
|
||||
assert llm_model.provider_resource_id == "gpt-4"
|
||||
|
||||
class TestOpenAIMixinCustomModelConstruction:
|
||||
"""Test cases for mixed model types (LLM, embedding, rerank) through construct_model_from_identifier"""
|
||||
|
||||
async def test_mixed_model_types_identification(self, mixin_with_custom_model_construction, mock_client_context):
|
||||
"""Test that LLM, embedding, and rerank models are correctly identified with proper types and metadata"""
|
||||
# Create mock models: 1 embedding, 1 rerank, 1 LLM
|
||||
mock_embedding_model = MagicMock(id="text-embedding-3-small")
|
||||
mock_rerank_model = MagicMock(id="rerank-model-1")
|
||||
mock_llm_model = MagicMock(id="gpt-4")
|
||||
mock_models = [mock_embedding_model, mock_rerank_model, mock_llm_model]
|
||||
|
||||
mock_client = MagicMock()
|
||||
|
||||
async def mock_models_list():
|
||||
for model in mock_models:
|
||||
yield model
|
||||
|
||||
mock_client.models.list.return_value = mock_models_list()
|
||||
|
||||
with mock_client_context(mixin_with_custom_model_construction, mock_client):
|
||||
result = await mixin_with_custom_model_construction.list_models()
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 3
|
||||
|
||||
expected_models = {
|
||||
"text-embedding-3-small": {
|
||||
"model_type": ModelType.embedding,
|
||||
"metadata": {"embedding_dimension": 1536, "context_length": 8192},
|
||||
"provider_id": "test-provider",
|
||||
"provider_resource_id": "text-embedding-3-small",
|
||||
},
|
||||
"rerank-model-1": {
|
||||
"model_type": ModelType.rerank,
|
||||
"metadata": {},
|
||||
"provider_id": "test-provider",
|
||||
"provider_resource_id": "rerank-model-1",
|
||||
},
|
||||
"gpt-4": {
|
||||
"model_type": ModelType.llm,
|
||||
"metadata": {},
|
||||
"provider_id": "test-provider",
|
||||
"provider_resource_id": "gpt-4",
|
||||
},
|
||||
}
|
||||
|
||||
_assert_models_match_expected(result, expected_models)
|
||||
|
||||
|
||||
class TestOpenAIMixinAllowedModels:
|
||||
|
|
@ -720,7 +814,7 @@ class TestOpenAIMixinProviderDataApiKey:
|
|||
):
|
||||
"""Test that ValueError is raised when provider data exists but doesn't have required key"""
|
||||
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
|
||||
with pytest.raises(ValueError, match="API key is not set"):
|
||||
with pytest.raises(ValueError, match="API key not provided"):
|
||||
_ = mixin_with_provider_data_field_and_none_api_key.client
|
||||
|
||||
def test_error_message_includes_correct_field_names(self, mixin_with_provider_data_field_and_none_api_key):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,77 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.stack import replace_env_vars
|
||||
from llama_stack.providers.remote.inference.anthropic import AnthropicConfig
|
||||
from llama_stack.providers.remote.inference.azure import AzureConfig
|
||||
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
|
||||
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
|
||||
from llama_stack.providers.remote.inference.databricks import DatabricksImplConfig
|
||||
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
|
||||
from llama_stack.providers.remote.inference.gemini import GeminiConfig
|
||||
from llama_stack.providers.remote.inference.groq import GroqConfig
|
||||
from llama_stack.providers.remote.inference.llama_openai_compat import LlamaCompatConfig
|
||||
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
|
||||
from llama_stack.providers.remote.inference.openai import OpenAIConfig
|
||||
from llama_stack.providers.remote.inference.runpod import RunpodImplConfig
|
||||
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
|
||||
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
|
||||
from llama_stack.providers.remote.inference.together import TogetherImplConfig
|
||||
from llama_stack.providers.remote.inference.vertexai import VertexAIConfig
|
||||
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
|
||||
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
|
||||
|
||||
|
||||
class TestRemoteInferenceProviderConfig:
|
||||
@pytest.mark.parametrize(
|
||||
"config_cls,alias_name,env_name,extra_config",
|
||||
[
|
||||
(AnthropicConfig, "api_key", "ANTHROPIC_API_KEY", {}),
|
||||
(AzureConfig, "api_key", "AZURE_API_KEY", {"api_base": "HTTP://FAKE"}),
|
||||
(BedrockConfig, None, None, {}),
|
||||
(CerebrasImplConfig, "api_key", "CEREBRAS_API_KEY", {}),
|
||||
(DatabricksImplConfig, "api_token", "DATABRICKS_TOKEN", {}),
|
||||
(FireworksImplConfig, "api_key", "FIREWORKS_API_KEY", {}),
|
||||
(GeminiConfig, "api_key", "GEMINI_API_KEY", {}),
|
||||
(GroqConfig, "api_key", "GROQ_API_KEY", {}),
|
||||
(LlamaCompatConfig, "api_key", "LLAMA_API_KEY", {}),
|
||||
(NVIDIAConfig, "api_key", "NVIDIA_API_KEY", {}),
|
||||
(OllamaImplConfig, None, None, {}),
|
||||
(OpenAIConfig, "api_key", "OPENAI_API_KEY", {}),
|
||||
(RunpodImplConfig, "api_token", "RUNPOD_API_TOKEN", {}),
|
||||
(SambaNovaImplConfig, "api_key", "SAMBANOVA_API_KEY", {}),
|
||||
(TGIImplConfig, None, None, {"url": "FAKE"}),
|
||||
(TogetherImplConfig, "api_key", "TOGETHER_API_KEY", {}),
|
||||
(VertexAIConfig, None, None, {"project": "FAKE", "location": "FAKE"}),
|
||||
(VLLMInferenceAdapterConfig, "api_token", "VLLM_API_TOKEN", {}),
|
||||
(WatsonXConfig, "api_key", "WATSONX_API_KEY", {}),
|
||||
],
|
||||
)
|
||||
def test_provider_config_auth_credentials(self, monkeypatch, config_cls, alias_name, env_name, extra_config):
|
||||
"""Test that the config class correctly maps the alias to auth_credential."""
|
||||
secret_value = config_cls.__name__
|
||||
|
||||
if alias_name is None:
|
||||
pytest.skip("No alias name provided for this config class.")
|
||||
|
||||
config = config_cls(**{alias_name: secret_value, **extra_config})
|
||||
assert config.auth_credential is not None
|
||||
assert config.auth_credential.get_secret_value() == secret_value
|
||||
|
||||
schema = config_cls.model_json_schema()
|
||||
assert alias_name in schema["properties"]
|
||||
assert "auth_credential" not in schema["properties"]
|
||||
|
||||
if env_name:
|
||||
monkeypatch.setenv(env_name, secret_value)
|
||||
sample_config = config_cls.sample_run_config()
|
||||
expanded_config = replace_env_vars(sample_config)
|
||||
config_from_sample = config_cls(**{**expanded_config, **extra_config})
|
||||
assert config_from_sample.auth_credential is not None
|
||||
assert config_from_sample.auth_credential.get_secret_value() == secret_value
|
||||
|
|
@ -9,38 +9,29 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from chromadb import PersistentClient
|
||||
from pymilvus import MilvusClient, connections
|
||||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
|
||||
from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
|
||||
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex, WeaviateVectorIOAdapter
|
||||
from llama_stack.providers.utils.kvstore import register_kvstore_backends
|
||||
|
||||
EMBEDDING_DIMENSION = 384
|
||||
EMBEDDING_DIMENSION = 768
|
||||
COLLECTION_PREFIX = "test_collection"
|
||||
MILVUS_ALIAS = "test_milvus"
|
||||
|
||||
|
||||
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector", "weaviate"])
|
||||
@pytest.fixture(params=["sqlite_vec", "faiss", "pgvector"])
|
||||
def vector_provider(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_db_id() -> str:
|
||||
def vector_store_id() -> str:
|
||||
return f"test-vector-db-{random.randint(1, 100)}"
|
||||
|
||||
|
||||
|
|
@ -122,8 +113,9 @@ async def unique_kvstore_config(tmp_path_factory):
|
|||
unique_id = f"test_kv_{np.random.randint(1e6)}"
|
||||
temp_dir = tmp_path_factory.getbasetemp()
|
||||
db_path = str(temp_dir / f"{unique_id}.db")
|
||||
|
||||
return SqliteKVStoreConfig(db_path=db_path)
|
||||
backend_name = f"kv_vector_{unique_id}"
|
||||
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
|
||||
return KVStoreReference(backend=backend_name, namespace=f"vector_io::{unique_id}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
|
@ -148,7 +140,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
|
|||
async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
|
||||
config = SQLiteVectorIOConfig(
|
||||
db_path=sqlite_vec_db_path,
|
||||
kvstore=unique_kvstore_config,
|
||||
persistence=unique_kvstore_config,
|
||||
)
|
||||
adapter = SQLiteVecVectorIOAdapter(
|
||||
config=config,
|
||||
|
|
@ -157,8 +149,8 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
|
|||
)
|
||||
collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
await adapter.register_vector_store(
|
||||
VectorStore(
|
||||
identifier=collection_id,
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
|
|
@ -170,46 +162,6 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
|
|||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def milvus_vec_db_path(tmp_path_factory):
|
||||
db_path = str(tmp_path_factory.getbasetemp() / "test_milvus.db")
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
|
||||
client = MilvusClient(milvus_vec_db_path)
|
||||
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
|
||||
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
|
||||
index = MilvusIndex(client, name, consistency_level="Strong")
|
||||
index.db_path = milvus_vec_db_path
|
||||
yield index
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def milvus_vec_adapter(milvus_vec_db_path, unique_kvstore_config, mock_inference_api):
|
||||
config = MilvusVectorIOConfig(
|
||||
db_path=milvus_vec_db_path,
|
||||
kvstore=unique_kvstore_config,
|
||||
)
|
||||
adapter = MilvusVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=adapter.metadata_collection_name,
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=128,
|
||||
)
|
||||
)
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def faiss_vec_db_path(tmp_path_factory):
|
||||
db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db")
|
||||
|
|
@ -226,7 +178,7 @@ async def faiss_vec_index(embedding_dimension):
|
|||
@pytest.fixture
|
||||
async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
|
||||
config = FaissVectorIOConfig(
|
||||
kvstore=unique_kvstore_config,
|
||||
persistence=unique_kvstore_config,
|
||||
)
|
||||
adapter = FaissVectorIOAdapter(
|
||||
config=config,
|
||||
|
|
@ -234,8 +186,8 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
|
|||
files_api=None,
|
||||
)
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
await adapter.register_vector_store(
|
||||
VectorStore(
|
||||
identifier=f"faiss_test_collection_{np.random.randint(1e6)}",
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
|
|
@ -246,98 +198,6 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
|
|||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chroma_vec_db_path(tmp_path_factory):
|
||||
persist_dir = tmp_path_factory.mktemp(f"chroma_{np.random.randint(1e6)}")
|
||||
return str(persist_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
|
||||
client = PersistentClient(path=chroma_vec_db_path)
|
||||
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
|
||||
collection = await maybe_await(client.get_or_create_collection(name))
|
||||
index = ChromaIndex(client=client, collection=collection)
|
||||
await index.initialize()
|
||||
yield index
|
||||
await index.delete()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def chroma_vec_adapter(chroma_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
|
||||
config = ChromaVectorIOConfig(
|
||||
db_path=chroma_vec_db_path,
|
||||
kvstore=unique_kvstore_config,
|
||||
)
|
||||
adapter = ChromaVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=f"chroma_test_collection_{random.randint(1, 1_000_000)}",
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
)
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qdrant_vec_db_path(tmp_path_factory):
|
||||
import uuid
|
||||
|
||||
db_path = str(tmp_path_factory.getbasetemp() / f"test_qdrant_{uuid.uuid4()}.db")
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def qdrant_vec_adapter(qdrant_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
|
||||
import uuid
|
||||
|
||||
config = QdrantVectorIOConfig(
|
||||
db_path=qdrant_vec_db_path,
|
||||
kvstore=unique_kvstore_config,
|
||||
)
|
||||
adapter = QdrantVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
collection_id = f"qdrant_test_collection_{uuid.uuid4()}"
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=collection_id,
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
)
|
||||
adapter.test_collection_id = collection_id
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
|
||||
import uuid
|
||||
|
||||
from qdrant_client import AsyncQdrantClient
|
||||
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantIndex
|
||||
|
||||
client = AsyncQdrantClient(path=qdrant_vec_db_path)
|
||||
collection_name = f"qdrant_test_collection_{uuid.uuid4()}"
|
||||
index = QdrantIndex(client, collection_name)
|
||||
yield index
|
||||
await index.delete()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_psycopg2_connection():
|
||||
connection = MagicMock()
|
||||
|
|
@ -355,7 +215,7 @@ def mock_psycopg2_connection():
|
|||
async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
|
||||
connection, cursor = mock_psycopg2_connection
|
||||
|
||||
vector_db = VectorDB(
|
||||
vector_store = VectorStore(
|
||||
identifier="test-vector-db",
|
||||
embedding_model="test-model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
|
|
@ -365,7 +225,7 @@ async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
|
|||
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values"):
|
||||
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
|
||||
index = PGVectorIndex(vector_store, embedding_dimension, connection, distance_metric="COSINE")
|
||||
index._test_chunks = []
|
||||
original_add_chunks = index.add_chunks
|
||||
|
||||
|
|
@ -393,7 +253,7 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
|
|||
db="test_db",
|
||||
user="test_user",
|
||||
password="test_password",
|
||||
kvstore=unique_kvstore_config,
|
||||
persistence=unique_kvstore_config,
|
||||
)
|
||||
|
||||
adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)
|
||||
|
|
@ -421,110 +281,41 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
|
|||
await adapter.initialize()
|
||||
adapter.conn = mock_conn
|
||||
|
||||
async def mock_insert_chunks(vector_db_id, chunks, ttl_seconds=None):
|
||||
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
|
||||
async def mock_insert_chunks(vector_store_id, chunks, ttl_seconds=None):
|
||||
index = await adapter._get_and_cache_vector_store_index(vector_store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
raise ValueError(f"Vector DB {vector_store_id} not found")
|
||||
await index.insert_chunks(chunks)
|
||||
|
||||
adapter.insert_chunks = mock_insert_chunks
|
||||
|
||||
async def mock_query_chunks(vector_db_id, query, params=None):
|
||||
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
|
||||
async def mock_query_chunks(vector_store_id, query, params=None):
|
||||
index = await adapter._get_and_cache_vector_store_index(vector_store_id)
|
||||
if not index:
|
||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||
raise ValueError(f"Vector DB {vector_store_id} not found")
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
adapter.query_chunks = mock_query_chunks
|
||||
|
||||
test_vector_db = VectorDB(
|
||||
test_vector_store = VectorStore(
|
||||
identifier=f"pgvector_test_collection_{random.randint(1, 1_000_000)}",
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
await adapter.register_vector_db(test_vector_db)
|
||||
adapter.test_collection_id = test_vector_db.identifier
|
||||
await adapter.register_vector_store(test_vector_store)
|
||||
adapter.test_collection_id = test_vector_store.identifier
|
||||
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def weaviate_vec_db_path(tmp_path_factory):
|
||||
db_path = str(tmp_path_factory.getbasetemp() / "test_weaviate.db")
|
||||
return db_path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def weaviate_vec_index(weaviate_vec_db_path):
|
||||
import pytest_socket
|
||||
import weaviate
|
||||
|
||||
pytest_socket.enable_socket()
|
||||
client = weaviate.connect_to_embedded(
|
||||
hostname="localhost",
|
||||
port=8080,
|
||||
grpc_port=50051,
|
||||
persistence_data_path=weaviate_vec_db_path,
|
||||
)
|
||||
index = WeaviateIndex(client=client, collection_name="Testcollection")
|
||||
await index.initialize()
|
||||
yield index
|
||||
await index.delete()
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def weaviate_vec_adapter(weaviate_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
|
||||
import pytest_socket
|
||||
import weaviate
|
||||
|
||||
pytest_socket.enable_socket()
|
||||
|
||||
client = weaviate.connect_to_embedded(
|
||||
hostname="localhost",
|
||||
port=8080,
|
||||
grpc_port=50051,
|
||||
persistence_data_path=weaviate_vec_db_path,
|
||||
)
|
||||
|
||||
config = WeaviateVectorIOConfig(
|
||||
weaviate_cluster_url="localhost:8080",
|
||||
weaviate_api_key=None,
|
||||
kvstore=unique_kvstore_config,
|
||||
)
|
||||
adapter = WeaviateVectorIOAdapter(
|
||||
config=config,
|
||||
inference_api=mock_inference_api,
|
||||
files_api=None,
|
||||
)
|
||||
collection_id = f"weaviate_test_collection_{random.randint(1, 1_000_000)}"
|
||||
await adapter.initialize()
|
||||
await adapter.register_vector_db(
|
||||
VectorDB(
|
||||
identifier=collection_id,
|
||||
provider_id="test_provider",
|
||||
embedding_model="test_model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
)
|
||||
)
|
||||
adapter.test_collection_id = collection_id
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
client.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_io_adapter(vector_provider, request):
|
||||
vector_provider_dict = {
|
||||
"milvus": "milvus_vec_adapter",
|
||||
"faiss": "faiss_vec_adapter",
|
||||
"sqlite_vec": "sqlite_vec_adapter",
|
||||
"chroma": "chroma_vec_adapter",
|
||||
"qdrant": "qdrant_vec_adapter",
|
||||
"pgvector": "pgvector_vec_adapter",
|
||||
"weaviate": "weaviate_vec_adapter",
|
||||
}
|
||||
return request.getfixturevalue(vector_provider_dict[vector_provider])
|
||||
|
||||
|
|
|
|||
|
|
@ -1,326 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.vector_io import QueryChunksResponse
|
||||
|
||||
# Mock the entire pymilvus module
|
||||
pymilvus_mock = MagicMock()
|
||||
pymilvus_mock.DataType = MagicMock()
|
||||
pymilvus_mock.MilvusClient = MagicMock
|
||||
pymilvus_mock.RRFRanker = MagicMock
|
||||
pymilvus_mock.WeightedRanker = MagicMock
|
||||
pymilvus_mock.AnnSearchRequest = MagicMock
|
||||
|
||||
# Apply the mock before importing MilvusIndex
|
||||
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
|
||||
|
||||
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
|
||||
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||
# tests/integration/vector_io/
|
||||
#
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest tests/unit/providers/vector_io/test_milvus.py \
|
||||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||
|
||||
MILVUS_PROVIDER = "milvus"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mock_milvus_client() -> MagicMock:
|
||||
"""Create a mock Milvus client with common method behaviors."""
|
||||
client = MagicMock()
|
||||
|
||||
# Mock collection operations
|
||||
client.has_collection.return_value = False # Initially no collection
|
||||
client.create_collection.return_value = None
|
||||
client.drop_collection.return_value = None
|
||||
|
||||
# Mock insert operation
|
||||
client.insert.return_value = {"insert_count": 10}
|
||||
|
||||
# Mock search operation - return mock results (data should be dict, not JSON string)
|
||||
client.search.return_value = [
|
||||
[
|
||||
{
|
||||
"id": 0,
|
||||
"distance": 0.1,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"distance": 0.2,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
# Mock query operation for keyword search (data should be dict, not JSON string)
|
||||
client.query.return_value = [
|
||||
{
|
||||
"chunk_id": "chunk1",
|
||||
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
|
||||
"score": 0.9,
|
||||
},
|
||||
{
|
||||
"chunk_id": "chunk2",
|
||||
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
|
||||
"score": 0.8,
|
||||
},
|
||||
{
|
||||
"chunk_id": "chunk3",
|
||||
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
|
||||
"score": 0.7,
|
||||
},
|
||||
]
|
||||
|
||||
return client
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def milvus_index(mock_milvus_client):
|
||||
"""Create a MilvusIndex with mocked client."""
|
||||
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
|
||||
yield index
|
||||
# No real cleanup needed since we're using mocks
|
||||
|
||||
|
||||
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||
# Setup: collection doesn't exist initially, then exists after creation
|
||||
mock_milvus_client.has_collection.side_effect = [False, True]
|
||||
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Verify collection was created and data was inserted
|
||||
mock_milvus_client.create_collection.assert_called_once()
|
||||
mock_milvus_client.insert.assert_called_once()
|
||||
|
||||
# Verify the insert call had the right number of chunks
|
||||
insert_call = mock_milvus_client.insert.call_args
|
||||
assert len(insert_call[1]["data"]) == len(sample_chunks)
|
||||
|
||||
|
||||
async def test_query_chunks_vector(
|
||||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
# Setup: Add chunks first
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Test vector search
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
mock_milvus_client.search.assert_called_once()
|
||||
|
||||
|
||||
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Test keyword search
|
||||
query_string = "Sentence 5"
|
||||
response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
|
||||
|
||||
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||
"""Test that when BM25 search fails, the system falls back to simple text search."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Force BM25 search to fail
|
||||
mock_milvus_client.search.side_effect = Exception("BM25 search not available")
|
||||
|
||||
# Mock simple text search results
|
||||
mock_milvus_client.query.return_value = [
|
||||
{
|
||||
"chunk_id": "chunk1",
|
||||
"chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
|
||||
},
|
||||
{
|
||||
"chunk_id": "chunk2",
|
||||
"chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
|
||||
},
|
||||
]
|
||||
|
||||
# Test keyword search that should fall back to simple text search
|
||||
query_string = "Python"
|
||||
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
|
||||
|
||||
# Verify response structure
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) > 0, "Fallback search should return results"
|
||||
|
||||
# Verify that simple text search was used (query method called instead of search)
|
||||
mock_milvus_client.query.assert_called_once()
|
||||
mock_milvus_client.search.assert_called_once() # Called once but failed
|
||||
|
||||
# Verify the query uses parameterized filter with filter_params
|
||||
query_call_args = mock_milvus_client.query.call_args
|
||||
assert "filter" in query_call_args[1], "Query should include filter for text search"
|
||||
assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
|
||||
assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"
|
||||
|
||||
# Verify all returned chunks have score 1.0 (simple binary scoring)
|
||||
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
|
||||
|
||||
|
||||
async def test_delete_collection(milvus_index, mock_milvus_client):
|
||||
# Test collection deletion
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
|
||||
await milvus_index.delete()
|
||||
|
||||
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
|
||||
|
||||
|
||||
async def test_query_hybrid_search_rrf(
|
||||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
"""Test hybrid search with RRF reranker."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Mock hybrid search results
|
||||
mock_milvus_client.hybrid_search.return_value = [
|
||||
[
|
||||
{
|
||||
"id": 0,
|
||||
"distance": 0.1,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"distance": 0.2,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
# Test hybrid search with RRF reranker
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
query_string = "test query"
|
||||
response = await milvus_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=2,
|
||||
score_threshold=0.0,
|
||||
reranker_type="rrf",
|
||||
reranker_params={"impact_factor": 60.0},
|
||||
)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
assert len(response.scores) == 2
|
||||
|
||||
# Verify hybrid search was called with correct parameters
|
||||
mock_milvus_client.hybrid_search.assert_called_once()
|
||||
call_args = mock_milvus_client.hybrid_search.call_args
|
||||
|
||||
# Check that the request contains both vector and BM25 search requests
|
||||
reqs = call_args[1]["reqs"]
|
||||
assert len(reqs) == 2
|
||||
assert reqs[0].anns_field == "vector"
|
||||
assert reqs[1].anns_field == "sparse"
|
||||
ranker = call_args[1]["ranker"]
|
||||
assert ranker is not None
|
||||
|
||||
|
||||
async def test_query_hybrid_search_weighted(
|
||||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
"""Test hybrid search with weighted reranker."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Mock hybrid search results
|
||||
mock_milvus_client.hybrid_search.return_value = [
|
||||
[
|
||||
{
|
||||
"id": 0,
|
||||
"distance": 0.1,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||
},
|
||||
{
|
||||
"id": 1,
|
||||
"distance": 0.2,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
# Test hybrid search with weighted reranker
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
query_string = "test query"
|
||||
response = await milvus_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=2,
|
||||
score_threshold=0.0,
|
||||
reranker_type="weighted",
|
||||
reranker_params={"alpha": 0.7},
|
||||
)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 2
|
||||
assert len(response.scores) == 2
|
||||
|
||||
# Verify hybrid search was called with correct parameters
|
||||
mock_milvus_client.hybrid_search.assert_called_once()
|
||||
call_args = mock_milvus_client.hybrid_search.call_args
|
||||
ranker = call_args[1]["ranker"]
|
||||
assert ranker is not None
|
||||
|
||||
|
||||
async def test_query_hybrid_search_default_rrf(
|
||||
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||
):
|
||||
"""Test hybrid search with default RRF reranker (no reranker_type specified)."""
|
||||
mock_milvus_client.has_collection.return_value = True
|
||||
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||
|
||||
# Mock hybrid search results
|
||||
mock_milvus_client.hybrid_search.return_value = [
|
||||
[
|
||||
{
|
||||
"id": 0,
|
||||
"distance": 0.1,
|
||||
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||
},
|
||||
]
|
||||
]
|
||||
|
||||
# Test hybrid search with default reranker (should be RRF)
|
||||
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||
query_string = "test query"
|
||||
response = await milvus_index.query_hybrid(
|
||||
embedding=query_embedding,
|
||||
query_string=query_string,
|
||||
k=1,
|
||||
score_threshold=0.0,
|
||||
reranker_type="unknown_type", # Should default to RRF
|
||||
reranker_params=None, # Should use default impact_factor
|
||||
)
|
||||
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == 1
|
||||
|
||||
# Verify hybrid search was called with RRF reranker
|
||||
mock_milvus_client.hybrid_search.assert_called_once()
|
||||
call_args = mock_milvus_client.hybrid_search.call_args
|
||||
ranker = call_args[1]["ranker"]
|
||||
assert ranker is not None
|
||||
|
|
@ -1,138 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex
|
||||
|
||||
PGVECTOR_PROVIDER = "pgvector"
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def loop():
|
||||
return asyncio.new_event_loop()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def embedding_dimension():
|
||||
"""Default embedding dimension for tests."""
|
||||
return 384
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def pgvector_index(embedding_dimension, mock_psycopg2_connection):
|
||||
"""Create a PGVectorIndex instance with mocked database connection."""
|
||||
connection, cursor = mock_psycopg2_connection
|
||||
|
||||
vector_db = VectorDB(
|
||||
identifier="test-vector-db",
|
||||
embedding_model="test-model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id=PGVECTOR_PROVIDER,
|
||||
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
|
||||
)
|
||||
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
|
||||
# Use explicit COSINE distance metric for consistent testing
|
||||
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
|
||||
|
||||
return index, cursor
|
||||
|
||||
|
||||
class TestPGVectorIndex:
|
||||
def test_distance_metric_validation(self, embedding_dimension, mock_psycopg2_connection):
|
||||
connection, cursor = mock_psycopg2_connection
|
||||
|
||||
vector_db = VectorDB(
|
||||
identifier="test-vector-db",
|
||||
embedding_model="test-model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id=PGVECTOR_PROVIDER,
|
||||
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
|
||||
)
|
||||
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
|
||||
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="L2")
|
||||
assert index.distance_metric == "L2"
|
||||
with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
|
||||
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID")
|
||||
|
||||
def test_get_pgvector_search_function(self, pgvector_index):
|
||||
index, cursor = pgvector_index
|
||||
supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION
|
||||
|
||||
for metric, function in supported_metrics.items():
|
||||
index.distance_metric = metric
|
||||
assert index.get_pgvector_search_function() == function
|
||||
|
||||
def test_check_distance_metric_availability(self, pgvector_index):
|
||||
index, cursor = pgvector_index
|
||||
supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION
|
||||
|
||||
for metric in supported_metrics:
|
||||
index.check_distance_metric_availability(metric)
|
||||
|
||||
with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
|
||||
index.check_distance_metric_availability("INVALID")
|
||||
|
||||
def test_constructor_invalid_distance_metric(self, embedding_dimension, mock_psycopg2_connection):
|
||||
connection, cursor = mock_psycopg2_connection
|
||||
|
||||
vector_db = VectorDB(
|
||||
identifier="test-vector-db",
|
||||
embedding_model="test-model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id=PGVECTOR_PROVIDER,
|
||||
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
|
||||
)
|
||||
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
|
||||
with pytest.raises(ValueError, match="Distance metric 'INVALID_METRIC' is not supported by PGVector"):
|
||||
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID_METRIC")
|
||||
|
||||
with pytest.raises(ValueError, match="Supported metrics are:"):
|
||||
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="UNKNOWN")
|
||||
|
||||
try:
|
||||
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
|
||||
assert index.distance_metric == "COSINE"
|
||||
except ValueError:
|
||||
pytest.fail("Valid distance metric 'COSINE' should not raise ValueError")
|
||||
|
||||
def test_constructor_all_supported_distance_metrics(self, embedding_dimension, mock_psycopg2_connection):
|
||||
connection, cursor = mock_psycopg2_connection
|
||||
|
||||
vector_db = VectorDB(
|
||||
identifier="test-vector-db",
|
||||
embedding_model="test-model",
|
||||
embedding_dimension=embedding_dimension,
|
||||
provider_id=PGVECTOR_PROVIDER,
|
||||
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
|
||||
)
|
||||
|
||||
supported_metrics = ["L2", "L1", "COSINE", "INNER_PRODUCT", "HAMMING", "JACCARD"]
|
||||
|
||||
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
|
||||
for metric in supported_metrics:
|
||||
try:
|
||||
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric=metric)
|
||||
assert index.distance_metric == metric
|
||||
|
||||
expected_operators = {
|
||||
"L2": "<->",
|
||||
"L1": "<+>",
|
||||
"COSINE": "<=>",
|
||||
"INNER_PRODUCT": "<#>",
|
||||
"HAMMING": "<~>",
|
||||
"JACCARD": "<%>",
|
||||
}
|
||||
assert index.get_pgvector_search_function() == expected_operators[metric]
|
||||
except Exception as e:
|
||||
pytest.fail(f"Valid distance metric '{metric}' should not raise exception: {e}")
|
||||
|
|
@ -11,8 +11,8 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.providers.datatypes import HealthStatus
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.faiss import (
|
||||
|
|
@ -39,12 +39,12 @@ def loop():
|
|||
|
||||
@pytest.fixture
|
||||
def embedding_dimension():
|
||||
return 384
|
||||
return 768
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def vector_db_id():
|
||||
return "test_vector_db"
|
||||
def vector_store_id():
|
||||
return "test_vector_store"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -61,12 +61,12 @@ def sample_embeddings(embedding_dimension):
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_db(vector_db_id, embedding_dimension) -> MagicMock:
|
||||
mock_vector_db = MagicMock(spec=VectorDB)
|
||||
mock_vector_db.embedding_model = "mock_embedding_model"
|
||||
mock_vector_db.identifier = vector_db_id
|
||||
mock_vector_db.embedding_dimension = embedding_dimension
|
||||
return mock_vector_db
|
||||
def mock_vector_store(vector_store_id, embedding_dimension) -> MagicMock:
|
||||
mock_vector_store = MagicMock(spec=VectorStore)
|
||||
mock_vector_store.embedding_model = "mock_embedding_model"
|
||||
mock_vector_store.identifier = vector_store_id
|
||||
mock_vector_store.embedding_dimension = embedding_dimension
|
||||
return mock_vector_store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -1,147 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inference.inference import OpenAIEmbeddingData, OpenAIEmbeddingsResponse, OpenAIEmbeddingUsage
|
||||
from llama_stack.apis.vector_io import (
|
||||
QueryChunksResponse,
|
||||
VectorDB,
|
||||
VectorDBStore,
|
||||
)
|
||||
from llama_stack.providers.inline.vector_io.qdrant.config import (
|
||||
QdrantVectorIOConfig as InlineQdrantVectorIOConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
|
||||
QdrantVectorIOAdapter,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
# This test is a unit test for the QdrantVectorIOAdapter class. This should only contain
|
||||
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||
# tests/integration/vector_io/
|
||||
#
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest tests/unit/providers/vector_io/test_qdrant.py \
|
||||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def qdrant_config(tmp_path) -> InlineQdrantVectorIOConfig:
|
||||
kvstore_config = SqliteKVStoreConfig(db_name=os.path.join(tmp_path, "test_kvstore.db"))
|
||||
return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db"), kvstore=kvstore_config)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def loop():
|
||||
return asyncio.new_event_loop()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_db(vector_db_id) -> MagicMock:
|
||||
mock_vector_db = MagicMock(spec=VectorDB)
|
||||
mock_vector_db.embedding_model = "embedding_model"
|
||||
mock_vector_db.identifier = vector_db_id
|
||||
mock_vector_db.embedding_dimension = 384
|
||||
mock_vector_db.model_dump_json.return_value = (
|
||||
'{"identifier": "'
|
||||
+ vector_db_id
|
||||
+ '", "provider_id": "qdrant", "embedding_model": "embedding_model", "embedding_dimension": 384}'
|
||||
)
|
||||
return mock_vector_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_vector_db_store(mock_vector_db) -> MagicMock:
|
||||
mock_store = MagicMock(spec=VectorDBStore)
|
||||
mock_store.get_vector_db = AsyncMock(return_value=mock_vector_db)
|
||||
return mock_store
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_api_service(sample_embeddings):
|
||||
mock_api_service = MagicMock(spec=Inference)
|
||||
mock_api_service.openai_embeddings = AsyncMock(
|
||||
return_value=OpenAIEmbeddingsResponse(
|
||||
model="mock-embedding-model",
|
||||
data=[OpenAIEmbeddingData(embedding=sample, index=i) for i, sample in enumerate(sample_embeddings)],
|
||||
usage=OpenAIEmbeddingUsage(prompt_tokens=10, total_tokens=10),
|
||||
)
|
||||
)
|
||||
return mock_api_service
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
|
||||
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service, files_api=None)
|
||||
adapter.vector_db_store = mock_vector_db_store
|
||||
await adapter.initialize()
|
||||
yield adapter
|
||||
await adapter.shutdown()
|
||||
|
||||
|
||||
__QUERY = "Sample query"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 60)])
|
||||
async def test_qdrant_adapter_returns_expected_chunks(
|
||||
qdrant_adapter: QdrantVectorIOAdapter,
|
||||
vector_db_id,
|
||||
sample_chunks,
|
||||
sample_embeddings,
|
||||
max_query_chunks,
|
||||
expected_chunks,
|
||||
) -> None:
|
||||
assert qdrant_adapter is not None
|
||||
await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
|
||||
|
||||
index = await qdrant_adapter._get_and_cache_vector_db_index(vector_db_id=vector_db_id)
|
||||
assert index is not None
|
||||
|
||||
response = await qdrant_adapter.query_chunks(
|
||||
query=__QUERY,
|
||||
vector_db_id=vector_db_id,
|
||||
params={"max_chunks": max_query_chunks, "mode": "vector"},
|
||||
)
|
||||
assert isinstance(response, QueryChunksResponse)
|
||||
assert len(response.chunks) == expected_chunks
|
||||
|
||||
|
||||
# To by-pass attempt to convert a Mock to JSON
|
||||
def _prepare_for_json(value: Any) -> str:
|
||||
return str(value)
|
||||
|
||||
|
||||
@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json)
|
||||
async def test_qdrant_register_and_unregister_vector_db(
|
||||
qdrant_adapter: QdrantVectorIOAdapter,
|
||||
mock_vector_db,
|
||||
sample_chunks,
|
||||
) -> None:
|
||||
# Initially, no collections
|
||||
vector_db_id = mock_vector_db.identifier
|
||||
assert len((await qdrant_adapter.client.get_collections()).collections) == 0
|
||||
|
||||
# Register does not create a collection
|
||||
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
|
||||
await qdrant_adapter.register_vector_db(mock_vector_db)
|
||||
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
|
||||
|
||||
# First insert creates the collection
|
||||
await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
|
||||
assert await qdrant_adapter.client.collection_exists(vector_db_id)
|
||||
|
||||
# Unregister deletes the collection
|
||||
await qdrant_adapter.unregister_vector_db(vector_db_id)
|
||||
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
|
||||
assert len((await qdrant_adapter.client.get_collections()).collections) == 0
|
||||
|
|
@ -12,14 +12,16 @@ import numpy as np
|
|||
import pytest
|
||||
|
||||
from llama_stack.apis.common.errors import VectorStoreNotFoundError
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
|
||||
OpenAICreateVectorStoreRequestWithExtraBody,
|
||||
QueryChunksResponse,
|
||||
VectorStoreChunkingStrategyAuto,
|
||||
VectorStoreFileObject,
|
||||
)
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX
|
||||
from llama_stack.apis.vector_stores import VectorStore
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import VECTOR_DBS_PREFIX
|
||||
|
||||
# This test is a unit test for the inline VectorIO providers. This should only contain
|
||||
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||
|
|
@ -69,7 +71,7 @@ async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimensio
|
|||
|
||||
async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
|
||||
key = f"{VECTOR_DBS_PREFIX}db1"
|
||||
dummy = VectorDB(
|
||||
dummy = VectorStore(
|
||||
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
|
||||
)
|
||||
await vector_io_adapter.kvstore.set(key=key, value=json.dumps(dummy.model_dump()))
|
||||
|
|
@ -79,10 +81,10 @@ async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
|
|||
|
||||
async def test_persistence_across_adapter_restarts(vector_io_adapter):
|
||||
await vector_io_adapter.initialize()
|
||||
dummy = VectorDB(
|
||||
dummy = VectorStore(
|
||||
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
|
||||
)
|
||||
await vector_io_adapter.register_vector_db(dummy)
|
||||
await vector_io_adapter.register_vector_store(dummy)
|
||||
await vector_io_adapter.shutdown()
|
||||
|
||||
await vector_io_adapter.initialize()
|
||||
|
|
@ -90,26 +92,22 @@ async def test_persistence_across_adapter_restarts(vector_io_adapter):
|
|||
await vector_io_adapter.shutdown()
|
||||
|
||||
|
||||
async def test_register_and_unregister_vector_db(vector_io_adapter):
|
||||
async def test_register_and_unregister_vector_store(vector_io_adapter):
|
||||
unique_id = f"foo_db_{np.random.randint(1e6)}"
|
||||
dummy = VectorDB(
|
||||
dummy = VectorStore(
|
||||
identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
|
||||
)
|
||||
|
||||
await vector_io_adapter.register_vector_db(dummy)
|
||||
await vector_io_adapter.register_vector_store(dummy)
|
||||
assert dummy.identifier in vector_io_adapter.cache
|
||||
await vector_io_adapter.unregister_vector_db(dummy.identifier)
|
||||
await vector_io_adapter.unregister_vector_store(dummy.identifier)
|
||||
assert dummy.identifier not in vector_io_adapter.cache
|
||||
|
||||
|
||||
async def test_query_unregistered_raises(vector_io_adapter, vector_provider):
|
||||
fake_emb = np.zeros(8, dtype=np.float32)
|
||||
if vector_provider == "chroma":
|
||||
with pytest.raises(AttributeError):
|
||||
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
|
||||
else:
|
||||
with pytest.raises(ValueError):
|
||||
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
|
||||
with pytest.raises(ValueError):
|
||||
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
|
||||
|
||||
|
||||
async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
|
||||
|
|
@ -123,12 +121,43 @@ async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
|
|||
|
||||
|
||||
async def test_insert_chunks_missing_db_raises(vector_io_adapter):
|
||||
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
|
||||
vector_io_adapter._get_and_cache_vector_store_index = AsyncMock(return_value=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await vector_io_adapter.insert_chunks("db_not_exist", [])
|
||||
|
||||
|
||||
async def test_insert_chunks_with_missing_document_id(vector_io_adapter):
|
||||
"""Ensure no KeyError when document_id is missing or in different places."""
|
||||
from llama_stack.apis.vector_io import Chunk, ChunkMetadata
|
||||
|
||||
fake_index = AsyncMock()
|
||||
vector_io_adapter.cache["db1"] = fake_index
|
||||
|
||||
# Various document_id scenarios that shouldn't crash
|
||||
chunks = [
|
||||
Chunk(content="has doc_id in metadata", metadata={"document_id": "doc-1"}),
|
||||
Chunk(content="no doc_id anywhere", metadata={"source": "test"}),
|
||||
Chunk(content="doc_id in chunk_metadata", chunk_metadata=ChunkMetadata(document_id="doc-3")),
|
||||
]
|
||||
|
||||
# Should work without KeyError
|
||||
await vector_io_adapter.insert_chunks("db1", chunks)
|
||||
fake_index.insert_chunks.assert_awaited_once()
|
||||
|
||||
|
||||
async def test_document_id_with_invalid_type_raises_error():
|
||||
"""Ensure TypeError is raised when document_id is not a string."""
|
||||
from llama_stack.apis.vector_io import Chunk
|
||||
|
||||
# Integer document_id should raise TypeError
|
||||
chunk = Chunk(content="test", metadata={"document_id": 12345})
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
_ = chunk.document_id
|
||||
assert "metadata['document_id'] must be a string" in str(exc_info.value)
|
||||
assert "got int" in str(exc_info.value)
|
||||
|
||||
|
||||
async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter):
|
||||
expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
|
||||
fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
|
||||
|
|
@ -141,7 +170,7 @@ async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter
|
|||
|
||||
|
||||
async def test_query_chunks_missing_db_raises(vector_io_adapter):
|
||||
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
|
||||
vector_io_adapter._get_and_cache_vector_store_index = AsyncMock(return_value=None)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await vector_io_adapter.query_chunks("db_missing", "q", None)
|
||||
|
|
@ -153,7 +182,7 @@ async def test_save_openai_vector_store(vector_io_adapter):
|
|||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"description": "A test OpenAI vector store",
|
||||
"vector_db_id": "test_db",
|
||||
"vector_store_id": "test_db",
|
||||
"embedding_model": "test_model",
|
||||
}
|
||||
|
||||
|
|
@ -169,7 +198,7 @@ async def test_update_openai_vector_store(vector_io_adapter):
|
|||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"description": "A test OpenAI vector store",
|
||||
"vector_db_id": "test_db",
|
||||
"vector_store_id": "test_db",
|
||||
"embedding_model": "test_model",
|
||||
}
|
||||
|
||||
|
|
@ -185,7 +214,7 @@ async def test_delete_openai_vector_store(vector_io_adapter):
|
|||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"description": "A test OpenAI vector store",
|
||||
"vector_db_id": "test_db",
|
||||
"vector_store_id": "test_db",
|
||||
"embedding_model": "test_model",
|
||||
}
|
||||
|
||||
|
|
@ -200,7 +229,7 @@ async def test_load_openai_vector_stores(vector_io_adapter):
|
|||
"id": store_id,
|
||||
"name": "Test Store",
|
||||
"description": "A test OpenAI vector store",
|
||||
"vector_db_id": "test_db",
|
||||
"vector_store_id": "test_db",
|
||||
"embedding_model": "test_model",
|
||||
}
|
||||
|
||||
|
|
@ -330,8 +359,7 @@ async def test_create_vector_store_file_batch(vector_io_adapter):
|
|||
vector_io_adapter._process_file_batch_async = AsyncMock()
|
||||
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
assert batch.vector_store_id == store_id
|
||||
|
|
@ -358,8 +386,7 @@ async def test_retrieve_vector_store_file_batch(vector_io_adapter):
|
|||
|
||||
# Create batch first
|
||||
created_batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# Retrieve batch
|
||||
|
|
@ -392,8 +419,7 @@ async def test_cancel_vector_store_file_batch(vector_io_adapter):
|
|||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# Cancel batch
|
||||
|
|
@ -438,8 +464,7 @@ async def test_list_files_in_vector_store_file_batch(vector_io_adapter):
|
|||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# List files
|
||||
|
|
@ -459,7 +484,7 @@ async def test_file_batch_validation_errors(vector_io_adapter):
|
|||
with pytest.raises(VectorStoreNotFoundError):
|
||||
await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id="nonexistent",
|
||||
file_ids=["file_1"],
|
||||
params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_1"]),
|
||||
)
|
||||
|
||||
# Setup store for remaining tests
|
||||
|
|
@ -476,8 +501,7 @@ async def test_file_batch_validation_errors(vector_io_adapter):
|
|||
# Test wrong vector store for batch
|
||||
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=["file_1"],
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_1"])
|
||||
)
|
||||
|
||||
# Create wrong_store so it exists but the batch doesn't belong to it
|
||||
|
|
@ -524,8 +548,7 @@ async def test_file_batch_pagination(vector_io_adapter):
|
|||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# Test pagination with limit
|
||||
|
|
@ -597,8 +620,7 @@ async def test_file_batch_status_filtering(vector_io_adapter):
|
|||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# Test filtering by completed status
|
||||
|
|
@ -640,8 +662,7 @@ async def test_cancel_completed_batch_fails(vector_io_adapter):
|
|||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# Manually update status to completed
|
||||
|
|
@ -675,8 +696,7 @@ async def test_file_batch_persistence_across_restarts(vector_io_adapter):
|
|||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
batch_id = batch.id
|
||||
|
||||
|
|
@ -731,8 +751,7 @@ async def test_cancelled_batch_persists_in_storage(vector_io_adapter):
|
|||
|
||||
# Create batch
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
batch_id = batch.id
|
||||
|
||||
|
|
@ -779,10 +798,10 @@ async def test_only_in_progress_batches_resumed(vector_io_adapter):
|
|||
|
||||
# Create multiple batches
|
||||
batch1 = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, file_ids=["file_1"]
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_1"])
|
||||
)
|
||||
batch2 = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, file_ids=["file_2"]
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_2"])
|
||||
)
|
||||
|
||||
# Complete one batch (should persist with completed status)
|
||||
|
|
@ -795,7 +814,7 @@ async def test_only_in_progress_batches_resumed(vector_io_adapter):
|
|||
|
||||
# Create a third batch that stays in progress
|
||||
batch3 = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id, file_ids=["file_3"]
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_3"])
|
||||
)
|
||||
|
||||
# Simulate restart - clear memory and reload from persistence
|
||||
|
|
@ -956,8 +975,7 @@ async def test_max_concurrent_files_per_batch(vector_io_adapter):
|
|||
file_ids = [f"file_{i}" for i in range(8)] # 8 files, but limit should be 5
|
||||
|
||||
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
|
||||
vector_store_id=store_id,
|
||||
file_ids=file_ids,
|
||||
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
|
||||
)
|
||||
|
||||
# Give time for the semaphore logic to start processing files
|
||||
|
|
@ -975,3 +993,130 @@ async def test_max_concurrent_files_per_batch(vector_io_adapter):
|
|||
assert batch.status == "in_progress"
|
||||
assert batch.file_counts.total == 8
|
||||
assert batch.file_counts.in_progress == 8
|
||||
|
||||
|
||||
async def test_embedding_config_from_metadata(vector_io_adapter):
|
||||
"""Test that embedding configuration is correctly extracted from metadata."""
|
||||
|
||||
# Mock register_vector_store to avoid actual registration
|
||||
vector_io_adapter.register_vector_store = AsyncMock()
|
||||
# Set provider_id attribute for the adapter
|
||||
vector_io_adapter.__provider_id__ = "test_provider"
|
||||
|
||||
# Test with embedding config in metadata
|
||||
params = OpenAICreateVectorStoreRequestWithExtraBody(
|
||||
name="test_store",
|
||||
metadata={
|
||||
"embedding_model": "test-embedding-model",
|
||||
"embedding_dimension": "512",
|
||||
},
|
||||
model_extra={},
|
||||
)
|
||||
|
||||
await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
||||
# Verify VectorStore was registered with correct embedding config from metadata
|
||||
vector_io_adapter.register_vector_store.assert_called_once()
|
||||
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
|
||||
assert call_args.embedding_model == "test-embedding-model"
|
||||
assert call_args.embedding_dimension == 512
|
||||
|
||||
|
||||
async def test_embedding_config_from_extra_body(vector_io_adapter):
|
||||
"""Test that embedding configuration is correctly extracted from extra_body when metadata is empty."""
|
||||
|
||||
# Mock register_vector_store to avoid actual registration
|
||||
vector_io_adapter.register_vector_store = AsyncMock()
|
||||
# Set provider_id attribute for the adapter
|
||||
vector_io_adapter.__provider_id__ = "test_provider"
|
||||
|
||||
# Test with embedding config in extra_body only (metadata has no embedding_model)
|
||||
params = OpenAICreateVectorStoreRequestWithExtraBody(
|
||||
name="test_store",
|
||||
metadata={}, # Empty metadata to ensure extra_body is used
|
||||
**{
|
||||
"embedding_model": "extra-body-model",
|
||||
"embedding_dimension": 1024,
|
||||
},
|
||||
)
|
||||
|
||||
await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
||||
# Verify VectorStore was registered with correct embedding config from extra_body
|
||||
vector_io_adapter.register_vector_store.assert_called_once()
|
||||
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
|
||||
assert call_args.embedding_model == "extra-body-model"
|
||||
assert call_args.embedding_dimension == 1024
|
||||
|
||||
|
||||
async def test_embedding_config_consistency_check_passes(vector_io_adapter):
|
||||
"""Test that consistent embedding config in both metadata and extra_body passes validation."""
|
||||
|
||||
# Mock register_vector_store to avoid actual registration
|
||||
vector_io_adapter.register_vector_store = AsyncMock()
|
||||
# Set provider_id attribute for the adapter
|
||||
vector_io_adapter.__provider_id__ = "test_provider"
|
||||
|
||||
# Test with consistent embedding config in both metadata and extra_body
|
||||
params = OpenAICreateVectorStoreRequestWithExtraBody(
|
||||
name="test_store",
|
||||
metadata={
|
||||
"embedding_model": "consistent-model",
|
||||
"embedding_dimension": "768",
|
||||
},
|
||||
**{
|
||||
"embedding_model": "consistent-model",
|
||||
"embedding_dimension": 768,
|
||||
},
|
||||
)
|
||||
|
||||
await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
||||
# Should not raise any error and use metadata config
|
||||
vector_io_adapter.register_vector_store.assert_called_once()
|
||||
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
|
||||
assert call_args.embedding_model == "consistent-model"
|
||||
assert call_args.embedding_dimension == 768
|
||||
|
||||
|
||||
async def test_embedding_config_defaults_when_missing(vector_io_adapter):
|
||||
"""Test that embedding dimension defaults to 768 when not provided."""
|
||||
|
||||
# Mock register_vector_store to avoid actual registration
|
||||
vector_io_adapter.register_vector_store = AsyncMock()
|
||||
# Set provider_id attribute for the adapter
|
||||
vector_io_adapter.__provider_id__ = "test_provider"
|
||||
|
||||
# Test with only embedding model, no dimension (metadata empty to use extra_body)
|
||||
params = OpenAICreateVectorStoreRequestWithExtraBody(
|
||||
name="test_store",
|
||||
metadata={}, # Empty metadata to ensure extra_body is used
|
||||
**{
|
||||
"embedding_model": "model-without-dimension",
|
||||
},
|
||||
)
|
||||
|
||||
await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
||||
# Should default to 768 dimensions
|
||||
vector_io_adapter.register_vector_store.assert_called_once()
|
||||
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
|
||||
assert call_args.embedding_model == "model-without-dimension"
|
||||
assert call_args.embedding_dimension == 768
|
||||
|
||||
|
||||
async def test_embedding_config_required_model_missing(vector_io_adapter):
|
||||
"""Test that missing embedding model raises error."""
|
||||
|
||||
# Mock register_vector_store to avoid actual registration
|
||||
vector_io_adapter.register_vector_store = AsyncMock()
|
||||
# Set provider_id attribute for the adapter
|
||||
vector_io_adapter.__provider_id__ = "test_provider"
|
||||
# Mock the default model lookup to return None (no default model available)
|
||||
vector_io_adapter._get_default_embedding_model_and_dimension = AsyncMock(return_value=None)
|
||||
|
||||
# Test with no embedding model provided
|
||||
params = OpenAICreateVectorStoreRequestWithExtraBody(name="test_store", metadata={})
|
||||
|
||||
with pytest.raises(ValueError, match="embedding_model is required"):
|
||||
await vector_io_adapter.openai_create_vector_store(params)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue