Merge origin/main into add-missing-provider-data-impls

Resolved conflicts in:
- benchmarking/k8s-benchmark/stack_run_config.yaml (accepted new storage schema)
- llama_stack/providers/remote/inference/cerebras/cerebras.py (kept provider data support)
- llama_stack/providers/remote/inference/cerebras/config.py (kept provider data support)
- llama_stack/providers/remote/inference/nvidia/config.py (kept provider data support)
- llama_stack/providers/remote/inference/runpod/config.py (merged imports)
- pyproject.toml (kept databricks-sdk dependency)
Author: Ashwin Bharambe
Date:   2025-10-27 11:39:00 -07:00
Commit: 9eb9a37ee4
1880 changed files with 804868 additions and 70533 deletions
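Several of the resolved files keep "provider data support", i.e. letting callers supply provider credentials per request rather than only via static config. A minimal sketch of that pattern, with illustrative names modeled on the Fake* test doubles further down in this diff (the real Cerebras/NVIDIA/RunPod classes differ):

from pydantic import BaseModel, Field

class ProviderConfig(BaseModel):
    # Static key from the stack config; may be absent
    api_key: str | None = Field(default=None)

class ProviderDataValidator(BaseModel):
    # Per-request key supplied by the client via provider data
    # (assumed to arrive through the X-LlamaStack-Provider-Data header)
    test_api_key: str | None = Field(default=None)

Adapters built on LiteLLMOpenAIMixin then resolve the effective key per request, typically preferring the provider-data key and falling back to the config key, which is what the test_litellm_openai_mixin changes below exercise.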


@@ -15,6 +15,7 @@ from llama_stack.apis.agents (
AgentCreateResponse,
)
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.inference import Inference
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolRuntime
@@ -25,6 +26,20 @@ from llama_stack.providers.inline.agents.meta_reference.config import MetaRefere
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo
@pytest.fixture(autouse=True)
def setup_backends(tmp_path):
"""Register KV and SQL store backends for testing."""
from llama_stack.core.storage.datatypes import SqliteKVStoreConfig, SqliteSqlStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
kv_path = str(tmp_path / "test_kv.db")
sql_path = str(tmp_path / "test_sql.db")
register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path=kv_path)})
register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path=sql_path)})
@pytest.fixture
def mock_apis():
return {
@@ -33,20 +48,26 @@ def mock_apis():
"safety_api": AsyncMock(spec=Safety),
"tool_runtime_api": AsyncMock(spec=ToolRuntime),
"tool_groups_api": AsyncMock(spec=ToolGroups),
"conversations_api": AsyncMock(spec=Conversations),
}
@pytest.fixture
def config(tmp_path):
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
from llama_stack.providers.inline.agents.meta_reference.config import AgentPersistenceConfig
return MetaReferenceAgentsImplConfig(
persistence_store={
"type": "sqlite",
"db_path": str(tmp_path / "test.db"),
},
responses_store={
"type": "sqlite",
"db_path": str(tmp_path / "test.db"),
},
persistence=AgentPersistenceConfig(
agent_state=KVStoreReference(
backend="kv_default",
namespace="agents",
),
responses=ResponsesStoreReference(
backend="sql_default",
table_name="responses",
),
)
)
@@ -59,7 +80,8 @@ async def agents_impl(config, mock_apis):
mock_apis["safety_api"],
mock_apis["tool_runtime_api"],
mock_apis["tool_groups_api"],
{},
mock_apis["conversations_api"],
[],
)
await impl.initialize()
yield impl
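The hunks above show the shape of the "new storage schema" mentioned in the conflict notes: store settings are no longer embedded inline in each provider config; backends are registered once under a name, and configs carry lightweight references. A condensed sketch of the same wiring the fixtures use (paths are placeholders):

from llama_stack.core.storage.datatypes import (
    KVStoreReference,
    ResponsesStoreReference,
    SqliteKVStoreConfig,
    SqliteSqlStoreConfig,
)
from llama_stack.providers.utils.kvstore.kvstore import register_kvstore_backends
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends

# 1. Register concrete backends once, keyed by name
register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path="/tmp/kv.db")})
register_sqlstore_backends({"sql_default": SqliteSqlStoreConfig(db_path="/tmp/sql.db")})

# 2. Configs now point at a backend plus a namespace/table
agent_state = KVStoreReference(backend="kv_default", namespace="agents")
responses = ResponsesStoreReference(backend="sql_default", table_name="responses")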


@@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, patch
import pytest
from openai.types.chat.chat_completion_chunk import (
@@ -20,9 +20,11 @@ from llama_stack.apis.agents.openai_responses import (
ListOpenAIResponseInputItem,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
OpenAIResponseInputToolWebSearch,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageWebSearchToolCall,
OpenAIResponseText,
@@ -32,15 +34,16 @@ from llama_stack.apis.agents.openai_responses import (
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIDeveloperMessageParam,
OpenAIJSONSchema,
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIUserMessageParam,
)
from llama_stack.apis.tools.tools import ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.tools.tools import ListToolDefsResponse, ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.core.datatypes import ResponsesStoreConfig
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
@@ -48,7 +51,7 @@ from llama_stack.providers.utils.responses.responses_store import (
ResponsesStore,
_OpenAIResponseObjectWithInputAndMessages,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
from tests.unit.providers.agents.meta_reference.fixtures import load_chat_completion_fixture
@@ -82,9 +85,28 @@ def mock_vector_io_api():
return vector_io_api
@pytest.fixture
def mock_conversations_api():
"""Mock conversations API for testing."""
mock_api = AsyncMock()
return mock_api
@pytest.fixture
def mock_safety_api():
safety_api = AsyncMock()
return safety_api
@pytest.fixture
def openai_responses_impl(
mock_inference_api, mock_tool_groups_api, mock_tool_runtime_api, mock_responses_store, mock_vector_io_api
mock_inference_api,
mock_tool_groups_api,
mock_tool_runtime_api,
mock_responses_store,
mock_vector_io_api,
mock_safety_api,
mock_conversations_api,
):
return OpenAIResponsesImpl(
inference_api=mock_inference_api,
@@ -92,6 +114,8 @@ def openai_responses_impl(
tool_runtime_api=mock_tool_runtime_api,
responses_store=mock_responses_store,
vector_io_api=mock_vector_io_api,
safety_api=mock_safety_api,
conversations_api=mock_conversations_api,
)
@@ -147,18 +171,24 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
chunks = [chunk async for chunk in result]
mock_inference_api.openai_chat_completion.assert_called_once_with(
model=model,
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
response_format=None,
tools=None,
stream=True,
temperature=0.1,
OpenAIChatCompletionRequestWithExtraBody(
model=model,
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
response_format=None,
tools=None,
stream=True,
temperature=0.1,
stream_options={
"include_usage": True,
},
)
)
# Should have content part events for text streaming
# Expected: response.created, content_part.added, output_text.delta, content_part.done, response.completed
assert len(chunks) >= 4
# Expected: response.created, response.in_progress, content_part.added, output_text.delta, content_part.done, response.completed
assert len(chunks) >= 5
assert chunks[0].type == "response.created"
assert any(chunk.type == "response.in_progress" for chunk in chunks)
# Check for content part events
content_part_added_events = [c for c in chunks if c.type == "response.content_part.added"]
@@ -169,6 +199,14 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
assert len(content_part_done_events) >= 1, "Should have content_part.done event for text"
assert len(text_delta_events) >= 1, "Should have text delta events"
added_event = content_part_added_events[0]
done_event = content_part_done_events[0]
assert added_event.content_index == 0
assert done_event.content_index == 0
assert added_event.output_index == done_event.output_index == 0
assert added_event.item_id == done_event.item_id
assert added_event.response_id == done_event.response_id
# Verify final event is completion
assert chunks[-1].type == "response.completed"
@@ -177,6 +215,8 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
assert final_response.model == model
assert len(final_response.output) == 1
assert isinstance(final_response.output[0], OpenAIResponseMessage)
assert final_response.output[0].id == added_event.item_id
assert final_response.id == added_event.response_id
openai_responses_impl.responses_store.store_response_object.assert_called_once()
assert final_response.output[0].content[0].text == "Dublin"
@@ -228,13 +268,15 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
# Verify
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == "What is the capital of Ireland?"
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
first_params = first_call.args[0]
assert first_params.messages[0].content == "What is the capital of Ireland?"
assert first_params.tools is not None
assert first_params.temperature == 0.1
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
assert second_call.kwargs["messages"][-1].content == "Dublin"
assert second_call.kwargs["temperature"] == 0.1
second_params = second_call.args[0]
assert second_params.messages[-1].content == "Dublin"
assert second_params.temperature == 0.1
openai_responses_impl.tool_groups_api.get_tool.assert_called_once_with("web_search")
openai_responses_impl.tool_runtime_api.invoke_tool.assert_called_once_with(
@@ -303,36 +345,42 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
chunks = [chunk async for chunk in result]
# Verify event types
# Should have: response.created, output_item.added, function_call_arguments.delta,
# function_call_arguments.done, output_item.done, response.completed
assert len(chunks) == 6
# Should have: response.created, response.in_progress, output_item.added,
# function_call_arguments.delta, function_call_arguments.done, output_item.done, response.completed
assert len(chunks) == 7
event_types = [chunk.type for chunk in chunks]
assert event_types == [
"response.created",
"response.in_progress",
"response.output_item.added",
"response.function_call_arguments.delta",
"response.function_call_arguments.done",
"response.output_item.done",
"response.completed",
]
# Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
first_params = first_call.args[0]
assert first_params.messages[0].content == input_text
assert first_params.tools is not None
assert first_params.temperature == 0.1
# Check response.created event (should have empty output)
assert chunks[0].type == "response.created"
assert len(chunks[0].response.output) == 0
# Check streaming events
assert chunks[1].type == "response.output_item.added"
assert chunks[2].type == "response.function_call_arguments.delta"
assert chunks[3].type == "response.function_call_arguments.done"
assert chunks[4].type == "response.output_item.done"
# Check response.completed event (should have the tool call)
assert chunks[5].type == "response.completed"
assert len(chunks[5].response.output) == 1
assert chunks[5].response.output[0].type == "function_call"
assert chunks[5].response.output[0].name == "get_weather"
completed_chunk = chunks[-1]
assert completed_chunk.type == "response.completed"
assert len(completed_chunk.response.output) == 1
assert completed_chunk.response.output[0].type == "function_call"
assert completed_chunk.response.output[0].name == "get_weather"
async def test_create_openai_response_with_tool_call_function_arguments_none(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with a tool call response that has a function that does not accept arguments, or arguments set to None when they are not mandatory."""
# Setup
"""Test creating an OpenAI response with tool calls that omit arguments."""
input_text = "What is the time right now?"
model = "meta-llama/Llama-3.1-8B-Instruct"
@@ -359,9 +407,22 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
object="chat.completion.chunk",
)
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
def assert_common_expectations(chunks) -> None:
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
first_params = first_call.args[0]
assert first_params.messages[0].content == input_text
assert first_params.tools is not None
assert first_params.temperature == 0.1
assert len(chunks[0].response.output) == 0
completed_chunk = chunks[-1]
assert completed_chunk.type == "response.completed"
assert len(completed_chunk.response.output) == 1
assert completed_chunk.response.output[0].type == "function_call"
assert completed_chunk.response.output[0].name == "get_current_time"
assert completed_chunk.response.output[0].arguments == "{}"
# Function does not accept arguments
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
result = await openai_responses_impl.create_openai_response(
input=input_text,
model=model,
@@ -369,46 +430,23 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
temperature=0.1,
tools=[
OpenAIResponseInputToolFunction(
name="get_current_time",
description="Get current time for system's timezone",
parameters={},
name="get_current_time", description="Get current time for system's timezone", parameters={}
)
],
)
# Check that we got the content from our mocked tool execution result
chunks = [chunk async for chunk in result]
# Verify event types
# Should have: response.created, output_item.added, function_call_arguments.delta,
# function_call_arguments.done, output_item.done, response.completed
assert len(chunks) == 5
# Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
# Check response.created event (should have empty output)
assert chunks[0].type == "response.created"
assert len(chunks[0].response.output) == 0
# Check streaming events
assert chunks[1].type == "response.output_item.added"
assert chunks[2].type == "response.function_call_arguments.done"
assert chunks[3].type == "response.output_item.done"
# Check response.completed event (should have the tool call with arguments set to "{}")
assert chunks[4].type == "response.completed"
assert len(chunks[4].response.output) == 1
assert chunks[4].response.output[0].type == "function_call"
assert chunks[4].response.output[0].name == "get_current_time"
assert chunks[4].response.output[0].arguments == "{}"
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
assert [chunk.type for chunk in chunks] == [
"response.created",
"response.in_progress",
"response.output_item.added",
"response.function_call_arguments.done",
"response.output_item.done",
"response.completed",
]
assert_common_expectations(chunks)
# Function accepts optional arguments
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
result = await openai_responses_impl.create_openai_response(
input=input_text,
model=model,
@@ -418,42 +456,47 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
OpenAIResponseInputToolFunction(
name="get_current_time",
description="Get current time for system's timezone",
parameters={
"timezone": "string",
},
parameters={"timezone": "string"},
)
],
)
# Check that we got the content from our mocked tool execution result
chunks = [chunk async for chunk in result]
assert [chunk.type for chunk in chunks] == [
"response.created",
"response.in_progress",
"response.output_item.added",
"response.function_call_arguments.done",
"response.output_item.done",
"response.completed",
]
assert_common_expectations(chunks)
# Verify event types
# Should have: response.created, output_item.added, function_call_arguments.delta,
# function_call_arguments.done, output_item.done, response.completed
assert len(chunks) == 5
# Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
# Check response.created event (should have empty output)
assert chunks[0].type == "response.created"
assert len(chunks[0].response.output) == 0
# Check streaming events
assert chunks[1].type == "response.output_item.added"
assert chunks[2].type == "response.function_call_arguments.done"
assert chunks[3].type == "response.output_item.done"
# Check response.completed event (should have the tool call with arguments set to "{}")
assert chunks[4].type == "response.completed"
assert len(chunks[4].response.output) == 1
assert chunks[4].response.output[0].type == "function_call"
assert chunks[4].response.output[0].name == "get_current_time"
assert chunks[4].response.output[0].arguments == "{}"
# Function accepts optional arguments with additional optional fields
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
result = await openai_responses_impl.create_openai_response(
input=input_text,
model=model,
stream=True,
temperature=0.1,
tools=[
OpenAIResponseInputToolFunction(
name="get_current_time",
description="Get current time for system's timezone",
parameters={"timezone": "string", "location": "string"},
)
],
)
chunks = [chunk async for chunk in result]
assert [chunk.type for chunk in chunks] == [
"response.created",
"response.in_progress",
"response.output_item.added",
"response.function_call_arguments.done",
"response.output_item.done",
"response.completed",
]
assert_common_expectations(chunks)
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
@@ -485,7 +528,9 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
# Verify that the correct messages were sent to the inference API, i.e.
# all of the response messages were converted to chat completion message objects
inference_messages = mock_inference_api.openai_chat_completion.call_args_list[0].kwargs["messages"]
call_args = mock_inference_api.openai_chat_completion.call_args_list[0]
params = call_args.args[0]
inference_messages = params.messages
for i, m in enumerate(input_messages):
if isinstance(m.content, str):
assert inference_messages[i].content == m.content
@@ -653,7 +698,8 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
# Verify
mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.kwargs["messages"]
params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message
assert len(sent_messages) == 2
@@ -691,7 +737,8 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
# Verify
mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.kwargs["messages"]
params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message
assert len(sent_messages) == 4 # 1 system + 3 input messages
@@ -751,7 +798,8 @@ async def test_create_openai_response_with_instructions_and_previous_response(
# Verify
mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.kwargs["messages"]
params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message
assert len(sent_messages) == 4, sent_messages
@@ -767,6 +815,69 @@ async def test_create_openai_response_with_instructions_and_previous_response(
assert sent_messages[3].content == "Which is the largest?"
async def test_create_openai_response_with_previous_response_instructions(
openai_responses_impl, mock_responses_store, mock_inference_api
):
"""Test prepending instructions and previous response with instructions."""
input_item_message = OpenAIResponseMessage(
id="123",
content="Name some towns in Ireland",
role="user",
)
response_output_message = OpenAIResponseMessage(
id="123",
content="Galway, Longford, Sligo",
status="completed",
role="assistant",
)
response = _OpenAIResponseObjectWithInputAndMessages(
created_at=1,
id="resp_123",
model="fake_model",
output=[response_output_message],
status="completed",
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
input=[input_item_message],
messages=[
OpenAIUserMessageParam(content="Name some towns in Ireland"),
OpenAIAssistantMessageParam(content="Galway, Longford, Sligo"),
],
instructions="You are a helpful assistant.",
)
mock_responses_store.get_response_object.return_value = response
model = "meta-llama/Llama-3.1-8B-Instruct"
instructions = "You are a geography expert. Provide concise answers."
mock_inference_api.openai_chat_completion.return_value = fake_stream()
# Execute
await openai_responses_impl.create_openai_response(
input="Which is the largest?", model=model, instructions=instructions, previous_response_id="123"
)
# Verify
mock_inference_api.openai_chat_completion.assert_called_once()
call_args = mock_inference_api.openai_chat_completion.call_args
params = call_args.args[0]
sent_messages = params.messages
# Check that instructions were prepended as a system message
# and that the previous response instructions were not carried over
assert len(sent_messages) == 4, sent_messages
assert sent_messages[0].role == "system"
assert sent_messages[0].content == instructions
# Check the rest of the messages were converted correctly
assert sent_messages[1].role == "user"
assert sent_messages[1].content == "Name some towns in Ireland"
assert sent_messages[2].role == "assistant"
assert sent_messages[2].content == "Galway, Longford, Sligo"
assert sent_messages[3].role == "user"
assert sent_messages[3].content == "Which is the largest?"
async def test_list_openai_response_input_items_delegation(openai_responses_impl, mock_responses_store):
"""Test that list_openai_response_input_items properly delegates to responses_store with correct parameters."""
# Setup
@@ -807,8 +918,10 @@ async def test_responses_store_list_input_items_logic():
# Create mock store and response store
mock_sql_store = AsyncMock()
backend_name = "sql_responses_test"
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path="mock_db_path")})
responses_store = ResponsesStore(
ResponsesStoreConfig(sql_store_config=SqliteSqlStoreConfig(db_path="mock_db_path")), policy=default_policy()
ResponsesStoreReference(backend=backend_name, table_name="responses"), policy=default_policy()
)
responses_store.sql_store = mock_sql_store
@@ -953,6 +1066,58 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
assert result.status == "completed"
@patch("llama_stack.providers.utils.tools.mcp.list_mcp_tools")
async def test_reuse_mcp_tool_list(
mock_list_mcp_tools, openai_responses_impl, mock_responses_store, mock_inference_api
):
"""Test that mcp_list_tools can be reused where appropriate."""
mock_inference_api.openai_chat_completion.return_value = fake_stream()
mock_list_mcp_tools.return_value = ListToolDefsResponse(
data=[ToolDef(name="test_tool", description="a test tool", input_schema={}, output_schema={})]
)
res1 = await openai_responses_impl.create_openai_response(
input="What is 2+2?",
model="meta-llama/Llama-3.1-8B-Instruct",
store=True,
tools=[
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
],
)
args = mock_responses_store.store_response_object.call_args
data = args.kwargs["response_object"].model_dump()
data["input"] = [input_item.model_dump() for input_item in args.kwargs["input"]]
data["messages"] = [msg.model_dump() for msg in args.kwargs["messages"]]
stored = _OpenAIResponseObjectWithInputAndMessages(**data)
mock_responses_store.get_response_object.return_value = stored
res2 = await openai_responses_impl.create_openai_response(
previous_response_id=res1.id,
input="Now what is 3+3?",
model="meta-llama/Llama-3.1-8B-Instruct",
store=True,
tools=[
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
],
)
assert len(mock_inference_api.openai_chat_completion.call_args_list) == 2
second_call = mock_inference_api.openai_chat_completion.call_args_list[1]
second_params = second_call.args[0]
tools_seen = second_params.tools
assert len(tools_seen) == 1
assert tools_seen[0]["function"]["name"] == "test_tool"
assert tools_seen[0]["function"]["description"] == "a test tool"
assert mock_list_mcp_tools.call_count == 1
listings = [obj for obj in res2.output if obj.type == "mcp_list_tools"]
assert len(listings) == 1
assert listings[0].server_label == "alabel"
assert len(listings[0].tools) == 1
assert listings[0].tools[0].name == "test_tool"
@pytest.mark.parametrize(
"text_format, response_format",
[
@@ -987,8 +1152,9 @@ async def test_create_openai_response_with_text_format(
# Verify
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["response_format"] == response_format
first_params = first_call.args[0]
assert first_params.messages[0].content == input_text
assert first_params.response_format == response_format
async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):
@@ -1004,3 +1170,75 @@ async def test_create_openai_response_with_invalid_text_format(openai_responses_
model=model,
text=OpenAIResponseText(format={"type": "invalid"}),
)
async def test_create_openai_response_with_output_types_as_input(
openai_responses_impl, mock_inference_api, mock_responses_store
):
"""Test that response outputs can be used as inputs in multi-turn conversations.
Before adding OpenAIResponseOutput types to OpenAIResponseInput,
creating a _OpenAIResponseObjectWithInputAndMessages with some output types
in the input field would fail with a Pydantic ValidationError.
This test simulates storing a response where the input contains output message
types (MCP calls, function calls), which happens in multi-turn conversations.
"""
model = "meta-llama/Llama-3.1-8B-Instruct"
# Mock the inference response
mock_inference_api.openai_chat_completion.return_value = fake_stream()
# Create a response with store=True to trigger the storage path
result = await openai_responses_impl.create_openai_response(
input="What's the weather?",
model=model,
stream=True,
temperature=0.1,
store=True,
)
# Consume the stream
_ = [chunk async for chunk in result]
# Verify store was called
assert mock_responses_store.store_response_object.called
# Get the stored data
store_call_args = mock_responses_store.store_response_object.call_args
stored_response = store_call_args.kwargs["response_object"]
# Now simulate a multi-turn conversation where outputs become inputs
input_with_output_types = [
OpenAIResponseMessage(role="user", content="What's the weather?", name=None),
# These output types need to be valid OpenAIResponseInput
OpenAIResponseOutputMessageFunctionToolCall(
call_id="call_123",
name="get_weather",
arguments='{"city": "Tokyo"}',
type="function_call",
),
OpenAIResponseOutputMessageMCPCall(
id="mcp_456",
type="mcp_call",
server_label="weather_server",
name="get_temperature",
arguments='{"location": "Tokyo"}',
output="25°C",
),
]
# This simulates storing a response in a multi-turn conversation
# where previous outputs are included in the input.
stored_with_outputs = _OpenAIResponseObjectWithInputAndMessages(
id=stored_response.id,
created_at=stored_response.created_at,
model=stored_response.model,
status=stored_response.status,
output=stored_response.output,
input=input_with_output_types, # This will trigger Pydantic validation
messages=None,
)
assert stored_with_outputs.input == input_with_output_types
assert len(stored_with_outputs.input) == 3


@@ -0,0 +1,249 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseMessage,
OpenAIResponseObject,
OpenAIResponseObjectStreamResponseCompleted,
OpenAIResponseObjectStreamResponseOutputItemDone,
OpenAIResponseOutputMessageContentOutputText,
)
from llama_stack.apis.common.errors import (
ConversationNotFoundError,
InvalidConversationIdError,
)
from llama_stack.apis.conversations.conversations import (
ConversationItemList,
)
# Import existing fixtures from the main responses test file
pytest_plugins = ["tests.unit.providers.agents.meta_reference.test_openai_responses"]
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
@pytest.fixture
def responses_impl_with_conversations(
mock_inference_api,
mock_tool_groups_api,
mock_tool_runtime_api,
mock_responses_store,
mock_vector_io_api,
mock_conversations_api,
mock_safety_api,
):
"""Create OpenAIResponsesImpl instance with conversations API."""
return OpenAIResponsesImpl(
inference_api=mock_inference_api,
tool_groups_api=mock_tool_groups_api,
tool_runtime_api=mock_tool_runtime_api,
responses_store=mock_responses_store,
vector_io_api=mock_vector_io_api,
conversations_api=mock_conversations_api,
safety_api=mock_safety_api,
)
class TestConversationValidation:
"""Test conversation ID validation logic."""
async def test_nonexistent_conversation_raises_error(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test that ConversationNotFoundError is raised for non-existent conversation."""
conv_id = "conv_nonexistent"
# Mock conversation not found
mock_conversations_api.list_items.side_effect = ConversationNotFoundError("conv_nonexistent")
with pytest.raises(ConversationNotFoundError):
await responses_impl_with_conversations.create_openai_response(
input="Hello", model="test-model", conversation=conv_id, stream=False
)
class TestMessageSyncing:
"""Test message syncing to conversations."""
async def test_sync_response_to_conversation_simple(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test syncing simple response to conversation."""
conv_id = "conv_test123"
input_text = "What are the 5 Ds of dodgeball?"
# Output items (what the model generated)
output_items = [
OpenAIResponseMessage(
id="msg_response",
content=[
OpenAIResponseOutputMessageContentOutputText(
text="The 5 Ds are: Dodge, Duck, Dip, Dive, and Dodge.", type="output_text", annotations=[]
)
],
role="assistant",
status="completed",
type="message",
)
]
await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_text, output_items)
# should call add_items with user input and assistant response
mock_conversations_api.add_items.assert_called_once()
call_args = mock_conversations_api.add_items.call_args
assert call_args[0][0] == conv_id # conversation_id
items = call_args[0][1] # conversation_items
assert len(items) == 2
# User message
assert items[0].type == "message"
assert items[0].role == "user"
assert items[0].content[0].type == "input_text"
assert items[0].content[0].text == input_text
# Assistant message
assert items[1].type == "message"
assert items[1].role == "assistant"
async def test_sync_response_to_conversation_api_error(
self, responses_impl_with_conversations, mock_conversations_api
):
mock_conversations_api.add_items.side_effect = Exception("API Error")
output_items = []
# matching the behavior of OpenAI here
with pytest.raises(Exception, match="API Error"):
await responses_impl_with_conversations._sync_response_to_conversation(
"conv_test123", "Hello", output_items
)
async def test_sync_with_list_input(self, responses_impl_with_conversations, mock_conversations_api):
"""Test syncing with list of input messages."""
conv_id = "conv_test123"
input_messages = [
OpenAIResponseMessage(role="user", content=[{"type": "input_text", "text": "First message"}]),
]
output_items = [
OpenAIResponseMessage(
id="msg_response",
content=[OpenAIResponseOutputMessageContentOutputText(text="Response", type="output_text")],
role="assistant",
status="completed",
type="message",
)
]
await responses_impl_with_conversations._sync_response_to_conversation(conv_id, input_messages, output_items)
mock_conversations_api.add_items.assert_called_once()
call_args = mock_conversations_api.add_items.call_args
items = call_args[0][1]
# Should have input message + output message
assert len(items) == 2
class TestIntegrationWorkflow:
"""Integration tests for the full conversation workflow."""
async def test_create_response_with_valid_conversation(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test creating a response with a valid conversation parameter."""
mock_conversations_api.list_items.return_value = ConversationItemList(
data=[], first_id=None, has_more=False, last_id=None, object="list"
)
async def mock_streaming_response(*args, **kwargs):
message_item = OpenAIResponseMessage(
id="msg_response",
content=[
OpenAIResponseOutputMessageContentOutputText(
text="Test response", type="output_text", annotations=[]
)
],
role="assistant",
status="completed",
type="message",
)
# Emit output_item.done event first (needed for conversation sync)
yield OpenAIResponseObjectStreamResponseOutputItemDone(
response_id="resp_test123",
item=message_item,
output_index=0,
sequence_number=1,
type="response.output_item.done",
)
# Then emit response.completed
mock_response = OpenAIResponseObject(
id="resp_test123",
created_at=1234567890,
model="test-model",
object="response",
output=[message_item],
status="completed",
)
yield OpenAIResponseObjectStreamResponseCompleted(response=mock_response, type="response.completed")
responses_impl_with_conversations._create_streaming_response = mock_streaming_response
input_text = "Hello, how are you?"
conversation_id = "conv_test123"
response = await responses_impl_with_conversations.create_openai_response(
input=input_text, model="test-model", conversation=conversation_id, stream=False
)
assert response is not None
assert response.id == "resp_test123"
# Note: conversation sync happens inside _create_streaming_response,
# which we're mocking here, so we can't test it in this unit test.
# The sync logic is tested separately in TestMessageSyncing.
async def test_create_response_with_invalid_conversation_id(self, responses_impl_with_conversations):
"""Test creating a response with an invalid conversation ID."""
with pytest.raises(InvalidConversationIdError) as exc_info:
await responses_impl_with_conversations.create_openai_response(
input="Hello", model="test-model", conversation="invalid_id", stream=False
)
assert "Expected an ID that begins with 'conv_'" in str(exc_info.value)
async def test_create_response_with_nonexistent_conversation(
self, responses_impl_with_conversations, mock_conversations_api
):
"""Test creating a response with a non-existent conversation."""
mock_conversations_api.list_items.side_effect = ConversationNotFoundError("conv_nonexistent")
with pytest.raises(ConversationNotFoundError) as exc_info:
await responses_impl_with_conversations.create_openai_response(
input="Hello", model="test-model", conversation="conv_nonexistent", stream=False
)
assert "not found" in str(exc_info.value)
async def test_conversation_and_previous_response_id(
self, responses_impl_with_conversations, mock_conversations_api, mock_responses_store
):
with pytest.raises(ValueError) as exc_info:
await responses_impl_with_conversations.create_openai_response(
input="test", model="test", conversation="conv_123", previous_response_id="resp_123"
)
assert "Mutually exclusive parameters" in str(exc_info.value)
assert "previous_response_id" in str(exc_info.value)
assert "conversation" in str(exc_info.value)


@@ -0,0 +1,183 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.agents.openai_responses import (
MCPListToolsTool,
OpenAIResponseInputToolFileSearch,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
OpenAIResponseInputToolWebSearch,
OpenAIResponseObject,
OpenAIResponseOutputMessageMCPListTools,
OpenAIResponseToolMCP,
)
from llama_stack.providers.inline.agents.meta_reference.responses.types import ToolContext
class TestToolContext:
def test_no_tools(self):
tools = []
context = ToolContext(tools)
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="mymodel", output=[], status="")
context.recover_tools_from_previous_response(previous_response)
assert len(context.tools_to_process) == 0
assert len(context.previous_tools) == 0
assert len(context.previous_tool_listings) == 0
def test_no_previous_tools(self):
tools = [
OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
OpenAIResponseInputToolMCP(server_label="label", server_url="url"),
]
context = ToolContext(tools)
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="mymodel", output=[], status="")
context.recover_tools_from_previous_response(previous_response)
assert len(context.tools_to_process) == 2
assert len(context.previous_tools) == 0
assert len(context.previous_tool_listings) == 0
def test_reusable_server(self):
tools = [
OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
]
context = ToolContext(tools)
output = [
OpenAIResponseOutputMessageMCPListTools(
id="test", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
)
]
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
previous_response.tools = [
OpenAIResponseInputToolFileSearch(vector_store_ids=["fake"]),
OpenAIResponseToolMCP(server_label="alabel"),
]
context.recover_tools_from_previous_response(previous_response)
assert len(context.tools_to_process) == 1
assert context.tools_to_process[0].type == "file_search"
assert len(context.previous_tools) == 1
assert context.previous_tools["test_tool"].server_label == "alabel"
assert context.previous_tools["test_tool"].server_url == "aurl"
assert len(context.previous_tool_listings) == 1
assert len(context.previous_tool_listings[0].tools) == 1
assert context.previous_tool_listings[0].server_label == "alabel"
def test_multiple_reusable_servers(self):
tools = [
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
OpenAIResponseInputToolWebSearch(),
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
]
context = ToolContext(tools)
output = [
OpenAIResponseOutputMessageMCPListTools(
id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool", input_schema={})]
),
OpenAIResponseOutputMessageMCPListTools(
id="test2",
server_label="anotherlabel",
tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
),
]
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
previous_response.tools = [
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
OpenAIResponseInputToolWebSearch(type="web_search"),
OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
]
context.recover_tools_from_previous_response(previous_response)
assert len(context.tools_to_process) == 2
assert context.tools_to_process[0].type == "function"
assert context.tools_to_process[1].type == "web_search"
assert len(context.previous_tools) == 2
assert context.previous_tools["test_tool"].server_label == "alabel"
assert context.previous_tools["test_tool"].server_url == "aurl"
assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
assert len(context.previous_tool_listings) == 2
assert len(context.previous_tool_listings[0].tools) == 1
assert context.previous_tool_listings[0].server_label == "alabel"
assert len(context.previous_tool_listings[1].tools) == 1
assert context.previous_tool_listings[1].server_label == "anotherlabel"
def test_multiple_servers_only_one_reusable(self):
tools = [
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
OpenAIResponseInputToolWebSearch(type="web_search"),
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl"),
]
context = ToolContext(tools)
output = [
OpenAIResponseOutputMessageMCPListTools(
id="test2",
server_label="anotherlabel",
tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
)
]
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
previous_response.tools = [
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
OpenAIResponseInputToolWebSearch(type="web_search"),
]
context.recover_tools_from_previous_response(previous_response)
assert len(context.tools_to_process) == 3
assert context.tools_to_process[0].type == "function"
assert context.tools_to_process[1].type == "web_search"
assert context.tools_to_process[2].type == "mcp"
assert len(context.previous_tools) == 1
assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
assert len(context.previous_tool_listings) == 1
assert len(context.previous_tool_listings[0].tools) == 1
assert context.previous_tool_listings[0].server_label == "anotherlabel"
def test_mismatched_allowed_tools(self):
tools = [
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseInputToolMCP(server_label="anotherlabel", server_url="anotherurl"),
OpenAIResponseInputToolWebSearch(type="web_search"),
OpenAIResponseInputToolMCP(server_label="alabel", server_url="aurl", allowed_tools=["test_tool_2"]),
]
context = ToolContext(tools)
output = [
OpenAIResponseOutputMessageMCPListTools(
id="test1", server_label="alabel", tools=[MCPListToolsTool(name="test_tool_1", input_schema={})]
),
OpenAIResponseOutputMessageMCPListTools(
id="test2",
server_label="anotherlabel",
tools=[MCPListToolsTool(name="some_other_tool", input_schema={})],
),
]
previous_response = OpenAIResponseObject(created_at=1234, id="test", model="fake", output=output, status="")
previous_response.tools = [
OpenAIResponseInputToolFunction(name="fake", parameters=None),
OpenAIResponseToolMCP(server_label="anotherlabel", server_url="anotherurl"),
OpenAIResponseInputToolWebSearch(type="web_search"),
OpenAIResponseToolMCP(server_label="alabel", server_url="aurl"),
]
context.recover_tools_from_previous_response(previous_response)
assert len(context.tools_to_process) == 3
assert context.tools_to_process[0].type == "function"
assert context.tools_to_process[1].type == "web_search"
assert context.tools_to_process[2].type == "mcp"
assert len(context.previous_tools) == 1
assert context.previous_tools["some_other_tool"].server_label == "anotherlabel"
assert context.previous_tools["some_other_tool"].server_url == "anotherurl"
assert len(context.previous_tool_listings) == 1
assert len(context.previous_tool_listings[0].tools) == 1
assert context.previous_tool_listings[0].server_label == "anotherlabel"
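Taken together, these cases fix the reuse rule: an MCP server's previous tool listing is reused only when the same server label appears in the previous response's tools with a matching allowed_tools selection; everything else lands back in tools_to_process. A hedged sketch of that predicate (the real logic lives in ToolContext.recover_tools_from_previous_response, and details beyond what the tests assert are assumptions):

def _can_reuse_listing(mcp_tool, previous_response) -> bool:
    # Find the matching server declaration in the previous response's tools
    prior = next(
        (t for t in previous_response.tools or []
         if getattr(t, "server_label", None) == mcp_tool.server_label),
        None,
    )
    if prior is None:
        return False
    # A changed allowed_tools selection invalidates the cached listing
    if getattr(prior, "allowed_tools", None) != getattr(mcp_tool, "allowed_tools", None):
        return False
    # The previous output must actually contain a listing for this server
    return any(
        o.type == "mcp_list_tools" and o.server_label == mcp_tool.server_label
        for o in previous_response.output
    )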


@@ -0,0 +1,155 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.agents.agents import ResponseGuardrailSpec
from llama_stack.apis.safety import ModerationObject, ModerationObjectResults
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
extract_guardrail_ids,
run_guardrails,
)
@pytest.fixture
def mock_apis():
"""Create mock APIs for testing."""
return {
"inference_api": AsyncMock(),
"tool_groups_api": AsyncMock(),
"tool_runtime_api": AsyncMock(),
"responses_store": AsyncMock(),
"vector_io_api": AsyncMock(),
"conversations_api": AsyncMock(),
"safety_api": AsyncMock(),
}
@pytest.fixture
def responses_impl(mock_apis):
"""Create OpenAIResponsesImpl instance with mocked dependencies."""
return OpenAIResponsesImpl(**mock_apis)
def test_extract_guardrail_ids_from_strings(responses_impl):
"""Test extraction from simple string guardrail IDs."""
guardrails = ["llama-guard", "content-filter", "nsfw-detector"]
result = extract_guardrail_ids(guardrails)
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
def test_extract_guardrail_ids_from_objects(responses_impl):
"""Test extraction from ResponseGuardrailSpec objects."""
guardrails = [
ResponseGuardrailSpec(type="llama-guard"),
ResponseGuardrailSpec(type="content-filter"),
]
result = extract_guardrail_ids(guardrails)
assert result == ["llama-guard", "content-filter"]
def test_extract_guardrail_ids_mixed_formats(responses_impl):
"""Test extraction from mixed string and object formats."""
guardrails = [
"llama-guard",
ResponseGuardrailSpec(type="content-filter"),
"nsfw-detector",
]
result = extract_guardrail_ids(guardrails)
assert result == ["llama-guard", "content-filter", "nsfw-detector"]
def test_extract_guardrail_ids_none_input(responses_impl):
"""Test extraction with None input."""
result = extract_guardrail_ids(None)
assert result == []
def test_extract_guardrail_ids_empty_list(responses_impl):
"""Test extraction with empty list."""
result = extract_guardrail_ids([])
assert result == []
def test_extract_guardrail_ids_unknown_format(responses_impl):
"""Test extraction with unknown guardrail format raises ValueError."""
# Create an object that's neither string nor ResponseGuardrailSpec
unknown_object = {"invalid": "format"} # Plain dict, not ResponseGuardrailSpec
guardrails = ["valid-guardrail", unknown_object, "another-guardrail"]
with pytest.raises(ValueError, match="Unknown guardrail format.*expected str or ResponseGuardrailSpec"):
extract_guardrail_ids(guardrails)
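The tests above fully determine the function's behavior; a sketch consistent with them (the real implementation lives in meta_reference.responses.utils):

def extract_guardrail_ids(guardrails) -> list[str]:
    if guardrails is None:
        return []
    ids: list[str] = []
    for guardrail in guardrails:
        if isinstance(guardrail, str):
            ids.append(guardrail)
        elif isinstance(guardrail, ResponseGuardrailSpec):
            ids.append(guardrail.type)
        else:
            raise ValueError(
                f"Unknown guardrail format: {guardrail!r}, expected str or ResponseGuardrailSpec"
            )
    return ids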
@pytest.fixture
def mock_safety_api():
"""Create mock safety API for guardrails testing."""
safety_api = AsyncMock()
# Mock the routing table and shields list for guardrails lookup
safety_api.routing_table = AsyncMock()
shield = AsyncMock()
shield.identifier = "llama-guard"
shield.provider_resource_id = "llama-guard-model"
safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
return safety_api
async def test_run_guardrails_no_violation(mock_safety_api):
"""Test guardrails validation with no violations."""
text = "Hello world"
guardrail_ids = ["llama-guard"]
# Mock moderation to return non-flagged content
unflagged_result = ModerationObjectResults(flagged=False, categories={"violence": False})
mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[unflagged_result])
mock_safety_api.run_moderation.return_value = mock_moderation_object
result = await run_guardrails(mock_safety_api, text, guardrail_ids)
assert result is None
# Verify run_moderation was called with the correct model
mock_safety_api.run_moderation.assert_called_once()
call_args = mock_safety_api.run_moderation.call_args
assert call_args[1]["model"] == "llama-guard-model"
async def test_run_guardrails_with_violation(mock_safety_api):
"""Test guardrails validation with safety violation."""
text = "Harmful content"
guardrail_ids = ["llama-guard"]
# Mock moderation to return flagged content
flagged_result = ModerationObjectResults(
flagged=True,
categories={"violence": True},
user_message="Content flagged by moderation",
metadata={"violation_type": ["S1"]},
)
mock_moderation_object = ModerationObject(id="test-mod-id", model="llama-guard-model", results=[flagged_result])
mock_safety_api.run_moderation.return_value = mock_moderation_object
result = await run_guardrails(mock_safety_api, text, guardrail_ids)
assert result == "Content flagged by moderation (flagged for: violence) (violation type: S1)"
async def test_run_guardrails_empty_inputs(mock_safety_api):
"""Test guardrails validation with empty inputs."""
# Test empty guardrail_ids
result = await run_guardrails(mock_safety_api, "test", [])
assert result is None
# Test empty text
result = await run_guardrails(mock_safety_api, "", ["llama-guard"])
assert result is None
# Test both empty
result = await run_guardrails(mock_safety_api, "", [])
assert result is None
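Likewise, a sketch of run_guardrails matching these assertions: each guardrail id resolves to a shield, the shield's provider_resource_id is used as the moderation model, and a flagged result is folded into a violation message. Call signatures beyond what the mocks assert are assumptions:

async def run_guardrails(safety_api, text: str, guardrail_ids: list[str]) -> str | None:
    if not text or not guardrail_ids:
        return None
    shields = (await safety_api.routing_table.list_shields()).data
    for guardrail_id in guardrail_ids:
        # Resolve the guardrail id to a registered shield
        shield = next(s for s in shields if s.identifier == guardrail_id)
        moderation = await safety_api.run_moderation(input=text, model=shield.provider_resource_id)
        for result in moderation.results:
            if result.flagged:
                flagged = [name for name, hit in (result.categories or {}).items() if hit]
                violation_types = (result.metadata or {}).get("violation_type", [])
                return (
                    f"{result.user_message} (flagged for: {', '.join(flagged)})"
                    f" (violation type: {', '.join(violation_types)})"
                )
    return None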


@@ -12,10 +12,10 @@ from unittest.mock import AsyncMock
import pytest
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
@pytest.fixture
@@ -23,8 +23,10 @@ async def provider():
"""Create a test provider instance with temporary database."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_batches.db"
backend_name = "kv_batches_test"
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
register_kvstore_backends({backend_name: kvstore_config})
config = ReferenceBatchesImplConfig(kvstore=KVStoreReference(backend=backend_name, namespace="batches"))
# Create kvstore and mock APIs
kvstore = await kvstore_impl(config.kvstore)


@@ -213,7 +213,6 @@ class TestReferenceBatchesImpl:
@pytest.mark.parametrize(
"endpoint",
[
"/v1/embeddings",
"/v1/invalid/endpoint",
"",
],
@@ -765,3 +764,12 @@ class TestReferenceBatchesImpl:
await asyncio.sleep(0.042) # let tasks start
assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"
async def test_create_batch_embeddings_endpoint(self, provider):
"""Test that batch creation succeeds with embeddings endpoint."""
batch = await provider.create_batch(
input_file_id="file_123",
endpoint="/v1/embeddings",
completion_window="24h",
)
assert batch.endpoint == "/v1/embeddings"


@@ -8,8 +8,9 @@ import boto3
import pytest
from moto import mock_aws
from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
class MockUploadFile:
@@ -38,11 +39,13 @@ def sample_text_file2():
def s3_config(tmp_path):
db_path = tmp_path / "s3_files_metadata.db"
backend_name = f"sql_s3_{tmp_path.name}"
register_sqlstore_backends({backend_name: SqliteSqlStoreConfig(db_path=db_path.as_posix())})
return S3FilesImplConfig(
bucket_name=f"test-bucket-{tmp_path.name}",
region="not-a-region",
auto_create_bucket=True,
metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()),
metadata_store=SqlStoreReference(backend=backend_name, table_name="s3_files_metadata"),
)


@@ -15,16 +15,16 @@ from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOp
# Test fixtures and helper classes
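# (renamed Test* -> Fake* below; most likely so pytest's collector no longer
# picks these helpers up as test classes -- an inference, the diff only shows the rename)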
class TestConfig(BaseModel):
class FakeConfig(BaseModel):
api_key: str | None = Field(default=None)
class TestProviderDataValidator(BaseModel):
class FakeProviderDataValidator(BaseModel):
test_api_key: str | None = Field(default=None)
class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: TestConfig):
class FakeLiteLLMAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: FakeConfig):
super().__init__(
litellm_provider_name="test",
api_key_from_config=config.api_key,
@@ -36,11 +36,11 @@ class TestLiteLLMAdapter(LiteLLMOpenAIMixin):
@pytest.fixture
def adapter_with_config_key():
"""Fixture to create adapter with API key in config"""
config = TestConfig(api_key="config-api-key")
adapter = TestLiteLLMAdapter(config)
config = FakeConfig(api_key="config-api-key")
adapter = FakeLiteLLMAdapter(config)
adapter.__provider_spec__ = MagicMock()
adapter.__provider_spec__.provider_data_validator = (
"tests.unit.providers.inference.test_litellm_openai_mixin.TestProviderDataValidator"
"tests.unit.providers.inference.test_litellm_openai_mixin.FakeProviderDataValidator"
)
return adapter
@@ -48,11 +48,11 @@ def adapter_with_config_key():
@pytest.fixture
def adapter_without_config_key():
"""Fixture to create adapter without API key in config"""
config = TestConfig(api_key=None)
adapter = TestLiteLLMAdapter(config)
config = FakeConfig(api_key=None)
adapter = FakeLiteLLMAdapter(config)
adapter.__provider_spec__ = MagicMock()
adapter.__provider_spec__.provider_data_validator = (
"tests.unit.providers.inference.test_litellm_openai_mixin.TestProviderDataValidator"
"tests.unit.providers.inference.test_litellm_openai_mixin.FakeProviderDataValidator"
)
return adapter


@@ -13,10 +13,16 @@ import pytest
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletion,
OpenAIChatCompletionRequestWithExtraBody,
OpenAIChoice,
OpenAICompletion,
OpenAICompletionChoice,
OpenAICompletionRequestWithExtraBody,
ToolChoice,
)
from llama_stack.apis.models import Model
from llama_stack.core.routers.inference import InferenceRouter
from llama_stack.core.routing_tables.models import ModelsRoutingTable
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import VLLMInferenceAdapter
@@ -56,13 +62,14 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
mock_client_property.return_value = mock_client
# No tools but auto tool choice
await vllm_inference_adapter.openai_chat_completion(
"mock-model",
[],
params = OpenAIChatCompletionRequestWithExtraBody(
model="mock-model",
messages=[{"role": "user", "content": "test"}],
stream=False,
tools=None,
tool_choice=ToolChoice.auto.value,
)
await vllm_inference_adapter.openai_chat_completion(params)
mock_client.chat.completions.create.assert_called()
call_args = mock_client.chat.completions.create.call_args
# Ensure tool_choice gets converted to none for older vLLM versions
@@ -171,9 +178,12 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
)
async def do_inference():
await vllm_inference_adapter.openai_chat_completion(
"mock-model", messages=["one fish", "two fish"], stream=False
params = OpenAIChatCompletionRequestWithExtraBody(
model="mock-model",
messages=[{"role": "user", "content": "one fish two fish"}],
stream=False,
)
await vllm_inference_adapter.openai_chat_completion(params)
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_create_client:
mock_client = MagicMock()
@@ -186,3 +196,148 @@ async def test_openai_chat_completion_is_async(vllm_inference_adapter):
assert mock_create_client.call_count == 4 # no cheating
assert total_time < (sleep_time * 2), f"Total time taken: {total_time}s exceeded expected max"
async def test_vllm_completion_extra_body():
"""
Test that vLLM-specific guided_choice and prompt_logprobs parameters are correctly forwarded
via extra_body to the underlying OpenAI client through the InferenceRouter.
"""
# Set up the vLLM adapter
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
vllm_adapter = VLLMInferenceAdapter(config=config)
vllm_adapter.__provider_id__ = "vllm"
await vllm_adapter.initialize()
# Create a mock model store
mock_model_store = AsyncMock()
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm")
mock_model_store.get_model.return_value = mock_model
mock_model_store.has_model.return_value = True
# Create a mock dist_registry
mock_dist_registry = MagicMock()
mock_dist_registry.get = AsyncMock(return_value=mock_model)
mock_dist_registry.set = AsyncMock()
# Set up the routing table
routing_table = ModelsRoutingTable(
impls_by_provider_id={"vllm": vllm_adapter},
dist_registry=mock_dist_registry,
policy=[],
)
# Inject the model store into the adapter
vllm_adapter.model_store = routing_table
# Create the InferenceRouter
router = InferenceRouter(routing_table=routing_table)
# Patch the OpenAI client
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()
mock_client.completions.create = AsyncMock(
return_value=OpenAICompletion(
id="cmpl-abc123",
created=1,
model="mock-model",
choices=[
OpenAICompletionChoice(
text="joy",
finish_reason="stop",
index=0,
)
],
)
)
mock_client_property.return_value = mock_client
# Test with guided_choice and prompt_logprobs as extra fields
params = OpenAICompletionRequestWithExtraBody(
model="mock-model",
prompt="I am feeling happy",
stream=False,
guided_choice=["joy", "sadness"],
prompt_logprobs=5,
)
await router.openai_completion(params)
# Verify that the client was called with extra_body containing both parameters
mock_client.completions.create.assert_called_once()
call_kwargs = mock_client.completions.create.call_args.kwargs
assert "extra_body" in call_kwargs
assert "guided_choice" in call_kwargs["extra_body"]
assert call_kwargs["extra_body"]["guided_choice"] == ["joy", "sadness"]
assert "prompt_logprobs" in call_kwargs["extra_body"]
assert call_kwargs["extra_body"]["prompt_logprobs"] == 5
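# A sketch of the pass-through being asserted: fields outside the standard
# completion schema ride along in `extra_body`, which the OpenAI Python client
# merges into the request JSON. The commented call shows the shape the adapter
# is expected to produce; it is a hedged illustration, not adapter code.
#
# await client.completions.create(
#     model="mock-model",
#     prompt="I am feeling happy",
#     extra_body={"guided_choice": ["joy", "sadness"], "prompt_logprobs": 5},
# )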
async def test_vllm_chat_completion_extra_body():
"""
Test that vLLM-specific parameters (e.g., chat_template_kwargs) are correctly forwarded
via extra_body to the underlying OpenAI client through the InferenceRouter for chat completion.
"""
# Set up the vLLM adapter
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
vllm_adapter = VLLMInferenceAdapter(config=config)
vllm_adapter.__provider_id__ = "vllm"
await vllm_adapter.initialize()
# Create a mock model store
mock_model_store = AsyncMock()
mock_model = Model(identifier="mock-model", provider_resource_id="mock-model", provider_id="vllm")
mock_model_store.get_model.return_value = mock_model
mock_model_store.has_model.return_value = True
# Create a mock dist_registry
mock_dist_registry = MagicMock()
mock_dist_registry.get = AsyncMock(return_value=mock_model)
mock_dist_registry.set = AsyncMock()
# Set up the routing table
routing_table = ModelsRoutingTable(
impls_by_provider_id={"vllm": vllm_adapter},
dist_registry=mock_dist_registry,
policy=[],
)
# Inject the model store into the adapter
vllm_adapter.model_store = routing_table
# Create the InferenceRouter
router = InferenceRouter(routing_table=routing_table)
# Patch the OpenAI client
with patch.object(VLLMInferenceAdapter, "client", new_callable=PropertyMock) as mock_client_property:
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(
return_value=OpenAIChatCompletion(
id="chatcmpl-abc123",
created=1,
model="mock-model",
choices=[
OpenAIChoice(
message=OpenAIAssistantMessageParam(
content="test response",
),
finish_reason="stop",
index=0,
)
],
)
)
mock_client_property.return_value = mock_client
# Test with chat_template_kwargs as extra field
params = OpenAIChatCompletionRequestWithExtraBody(
model="mock-model",
messages=[{"role": "user", "content": "test"}],
stream=False,
chat_template_kwargs={"thinking": True},
)
await router.openai_chat_completion(params)
# Verify that the client was called with extra_body containing chat_template_kwargs
mock_client.chat.completions.create.assert_called_once()
call_kwargs = mock_client.chat.completions.create.call_args.kwargs
assert "extra_body" in call_kwargs
assert "chat_template_kwargs" in call_kwargs["extra_body"]
assert call_kwargs["extra_body"]["chat_template_kwargs"] == {"thinking": True}
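# The mechanism under test, sketched under the assumption that the
# ...RequestWithExtraBody models are Pydantic models configured with
# extra="allow": unknown fields survive validation and can be split out and
# forwarded as extra_body. `_RequestSketch` is a stand-in, not the real model.
def _extra_body_sketch():
    from pydantic import BaseModel, ConfigDict

    class _RequestSketch(BaseModel):
        model_config = ConfigDict(extra="allow")
        model: str

    req = _RequestSketch(model="mock-model", chat_template_kwargs={"thinking": True})
    # everything not declared on the model is a candidate for extra_body
    extra = {k: v for k, v in req.model_dump().items() if k not in _RequestSketch.model_fields}
    assert extra == {"chat_template_kwargs": {"thinking": True}}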

View file

@ -4,10 +4,45 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.tools import ToolDef
from llama_stack.providers.inline.agents.meta_reference.responses.streaming import (
convert_tooldef_to_chat_tool,
)
from llama_stack.providers.inline.agents.meta_reference.responses.types import ChatCompletionContext
@pytest.fixture
def mock_safety_api():
safety_api = AsyncMock()
# Mock the routing table and shields list for guardrails lookup
safety_api.routing_table = AsyncMock()
shield = AsyncMock()
shield.identifier = "llama-guard"
shield.provider_resource_id = "llama-guard-model"
safety_api.routing_table.list_shields.return_value = AsyncMock(data=[shield])
# Mock run_moderation to return non-flagged result by default
safety_api.run_moderation.return_value = AsyncMock(flagged=False)
return safety_api
@pytest.fixture
def mock_inference_api():
inference_api = AsyncMock()
return inference_api
@pytest.fixture
def mock_context():
context = AsyncMock(spec=ChatCompletionContext)
# Add required attributes that StreamingResponseOrchestrator expects
context.tool_context = AsyncMock()
context.tool_context.previous_tools = {}
context.messages = []
return context
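# Context for the test below, stated as an assumption about the conversion:
# array-typed tool parameters carry a JSON Schema "items" sub-schema, e.g.
#
#   {"type": "array", "items": {"type": "string"}}
#
# and convert_tooldef_to_chat_tool must keep it, since dropping "items" would
# make the emitted chat tool schema invalid for providers that validate tool
# parameters strictly.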
def test_convert_tooldef_to_chat_tool_preserves_items_field():

View file

@ -19,7 +19,7 @@ from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
class TestNVIDIASafetyAdapter(NVIDIASafetyAdapter):
class FakeNVIDIASafetyAdapter(NVIDIASafetyAdapter):
"""Test implementation that provides the required shield_store."""
def __init__(self, config: NVIDIASafetyConfig, shield_store):
@ -41,7 +41,7 @@ def nvidia_adapter():
shield_store = AsyncMock()
shield_store.get_shield = AsyncMock()
adapter = TestNVIDIASafetyAdapter(config=config, shield_store=shield_store)
adapter = FakeNVIDIASafetyAdapter(config=config, shield_store=shield_store)
return adapter

View file

@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
import pytest
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Model, OpenAIUserMessageParam
from llama_stack.apis.inference import Model, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
from llama_stack.apis.models import ModelType
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
@ -23,10 +23,10 @@ class OpenAIMixinImpl(OpenAIMixin):
__provider_id__: str = "test-provider"
def get_api_key(self) -> str:
raise NotImplementedError("This method should be mocked in tests")
return "test-api-key"
def get_base_url(self) -> str:
raise NotImplementedError("This method should be mocked in tests")
return "http://test-base-url"
class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
@ -38,6 +38,28 @@ class OpenAIMixinWithEmbeddingsImpl(OpenAIMixinImpl):
}
class OpenAIMixinWithCustomModelConstruction(OpenAIMixinImpl):
"""Test implementation that uses construct_model_from_identifier to add rerank models"""
embedding_model_metadata: dict[str, dict[str, int]] = {
"text-embedding-3-small": {"embedding_dimension": 1536, "context_length": 8192},
"text-embedding-ada-002": {"embedding_dimension": 1536, "context_length": 8192},
}
# Adds rerank models via construct_model_from_identifier
rerank_model_ids: set[str] = {"rerank-model-1", "rerank-model-2"}
def construct_model_from_identifier(self, identifier: str) -> Model:
if identifier in self.rerank_model_ids:
return Model(
provider_id=self.__provider_id__, # type: ignore[attr-defined]
provider_resource_id=identifier,
identifier=identifier,
model_type=ModelType.rerank,
)
return super().construct_model_from_identifier(identifier)
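def _rerank_hook_sketch():
    # Hedged illustration (not invoked by the suite): identifiers listed in
    # rerank_model_ids come back typed as rerank models, everything else falls
    # through to the default construction -- the same per-identifier mapping
    # list_models() applies to each id returned by the provider.
    impl = OpenAIMixinWithCustomModelConstruction(config=RemoteInferenceProviderConfig())
    assert impl.construct_model_from_identifier("rerank-model-1").model_type == ModelType.rerank
    assert impl.construct_model_from_identifier("gpt-4").model_type == ModelType.llm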
@pytest.fixture
def mixin():
"""Create a test instance of OpenAIMixin with mocked model_store"""
@ -62,6 +84,13 @@ def mixin_with_embeddings():
return OpenAIMixinWithEmbeddingsImpl(config=config)
@pytest.fixture
def mixin_with_custom_model_construction():
"""Create a test instance using custom construct_model_from_identifier"""
config = RemoteInferenceProviderConfig()
return OpenAIMixinWithCustomModelConstruction(config=config)
@pytest.fixture
def mock_models():
"""Create multiple mock OpenAI model objects"""
@ -113,6 +142,19 @@ def mock_client_context():
return _mock_client_context
def _assert_models_match_expected(actual_models, expected_models):
"""Verify the models match expected attributes.
Args:
actual_models: List of models to verify
expected_models: Mapping of model identifier to expected attribute values
"""
for identifier, expected_attrs in expected_models.items():
model = next(m for m in actual_models if m.identifier == identifier)
for attr_name, expected_value in expected_attrs.items():
assert getattr(model, attr_name) == expected_value
class TestOpenAIMixinListModels:
"""Test cases for the list_models method"""
@ -205,7 +247,7 @@ class TestOpenAIMixinCheckModelAvailability:
assert await mixin.check_model_availability("pre-registered-model")
# Should not call the provider's list_models since model was found in store
mock_client_with_models.models.list.assert_not_called()
mock_model_store.has_model.assert_called_once_with("pre-registered-model")
mock_model_store.has_model.assert_called_once_with("test-provider/pre-registered-model")
async def test_check_model_availability_fallback_to_provider_when_not_in_store(
self, mixin, mock_client_with_models, mock_client_context
@ -222,7 +264,7 @@ class TestOpenAIMixinCheckModelAvailability:
assert await mixin.check_model_availability("some-mock-model-id")
# Should call the provider's list_models since model was not found in store
mock_client_with_models.models.list.assert_called_once()
mock_model_store.has_model.assert_called_once_with("some-mock-model-id")
mock_model_store.has_model.assert_called_once_with("test-provider/some-mock-model-id")
class TestOpenAIMixinCacheBehavior:
@ -271,7 +313,8 @@ class TestOpenAIMixinImagePreprocessing:
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
mock_localize.return_value = (b"fake_image_data", "jpeg")
await mixin.openai_chat_completion(model="test-model", messages=[message])
params = OpenAIChatCompletionRequestWithExtraBody(model="test-model", messages=[message])
await mixin.openai_chat_completion(params)
mock_localize.assert_called_once_with("http://example.com/image.jpg")
@ -303,7 +346,8 @@ class TestOpenAIMixinImagePreprocessing:
with patch.object(type(mixin), "client", new_callable=PropertyMock, return_value=mock_client):
with patch("llama_stack.providers.utils.inference.openai_mixin.localize_image_content") as mock_localize:
await mixin.openai_chat_completion(model="test-model", messages=[message])
params = OpenAIChatCompletionRequestWithExtraBody(model="test-model", messages=[message])
await mixin.openai_chat_completion(params)
mock_localize.assert_not_called()
@ -340,21 +384,71 @@ class TestOpenAIMixinEmbeddingModelMetadata:
assert result is not None
assert len(result) == 2
# Find the models in the result
embedding_model = next(m for m in result if m.identifier == "text-embedding-3-small")
llm_model = next(m for m in result if m.identifier == "gpt-4")
expected_models = {
"text-embedding-3-small": {
"model_type": ModelType.embedding,
"metadata": {"embedding_dimension": 1536, "context_length": 8192},
"provider_id": "test-provider",
"provider_resource_id": "text-embedding-3-small",
},
"gpt-4": {
"model_type": ModelType.llm,
"metadata": {},
"provider_id": "test-provider",
"provider_resource_id": "gpt-4",
},
}
# Check embedding model
assert embedding_model.model_type == ModelType.embedding
assert embedding_model.metadata == {"embedding_dimension": 1536, "context_length": 8192}
assert embedding_model.provider_id == "test-provider"
assert embedding_model.provider_resource_id == "text-embedding-3-small"
_assert_models_match_expected(result, expected_models)
# Check LLM model
assert llm_model.model_type == ModelType.llm
assert llm_model.metadata == {} # No metadata for LLMs
assert llm_model.provider_id == "test-provider"
assert llm_model.provider_resource_id == "gpt-4"
class TestOpenAIMixinCustomModelConstruction:
"""Test cases for mixed model types (LLM, embedding, rerank) through construct_model_from_identifier"""
async def test_mixed_model_types_identification(self, mixin_with_custom_model_construction, mock_client_context):
"""Test that LLM, embedding, and rerank models are correctly identified with proper types and metadata"""
# Create mock models: 1 embedding, 1 rerank, 1 LLM
mock_embedding_model = MagicMock(id="text-embedding-3-small")
mock_rerank_model = MagicMock(id="rerank-model-1")
mock_llm_model = MagicMock(id="gpt-4")
mock_models = [mock_embedding_model, mock_rerank_model, mock_llm_model]
mock_client = MagicMock()
async def mock_models_list():
for model in mock_models:
yield model
mock_client.models.list.return_value = mock_models_list()
with mock_client_context(mixin_with_custom_model_construction, mock_client):
result = await mixin_with_custom_model_construction.list_models()
assert result is not None
assert len(result) == 3
expected_models = {
"text-embedding-3-small": {
"model_type": ModelType.embedding,
"metadata": {"embedding_dimension": 1536, "context_length": 8192},
"provider_id": "test-provider",
"provider_resource_id": "text-embedding-3-small",
},
"rerank-model-1": {
"model_type": ModelType.rerank,
"metadata": {},
"provider_id": "test-provider",
"provider_resource_id": "rerank-model-1",
},
"gpt-4": {
"model_type": ModelType.llm,
"metadata": {},
"provider_id": "test-provider",
"provider_resource_id": "gpt-4",
},
}
_assert_models_match_expected(result, expected_models)
class TestOpenAIMixinAllowedModels:
@ -720,7 +814,7 @@ class TestOpenAIMixinProviderDataApiKey:
):
"""Test that ValueError is raised when provider data exists but doesn't have required key"""
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"wrong_key": "some-value"})}):
with pytest.raises(ValueError, match="API key is not set"):
with pytest.raises(ValueError, match="API key not provided"):
_ = mixin_with_provider_data_field_and_none_api_key.client
def test_error_message_includes_correct_field_names(self, mixin_with_provider_data_field_and_none_api_key):

View file

@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.core.stack import replace_env_vars
from llama_stack.providers.remote.inference.anthropic import AnthropicConfig
from llama_stack.providers.remote.inference.azure import AzureConfig
from llama_stack.providers.remote.inference.bedrock import BedrockConfig
from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
from llama_stack.providers.remote.inference.databricks import DatabricksImplConfig
from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
from llama_stack.providers.remote.inference.gemini import GeminiConfig
from llama_stack.providers.remote.inference.groq import GroqConfig
from llama_stack.providers.remote.inference.llama_openai_compat import LlamaCompatConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
from llama_stack.providers.remote.inference.openai import OpenAIConfig
from llama_stack.providers.remote.inference.runpod import RunpodImplConfig
from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig
from llama_stack.providers.remote.inference.tgi import TGIImplConfig
from llama_stack.providers.remote.inference.together import TogetherImplConfig
from llama_stack.providers.remote.inference.vertexai import VertexAIConfig
from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
class TestRemoteInferenceProviderConfig:
@pytest.mark.parametrize(
"config_cls,alias_name,env_name,extra_config",
[
(AnthropicConfig, "api_key", "ANTHROPIC_API_KEY", {}),
(AzureConfig, "api_key", "AZURE_API_KEY", {"api_base": "HTTP://FAKE"}),
(BedrockConfig, None, None, {}),
(CerebrasImplConfig, "api_key", "CEREBRAS_API_KEY", {}),
(DatabricksImplConfig, "api_token", "DATABRICKS_TOKEN", {}),
(FireworksImplConfig, "api_key", "FIREWORKS_API_KEY", {}),
(GeminiConfig, "api_key", "GEMINI_API_KEY", {}),
(GroqConfig, "api_key", "GROQ_API_KEY", {}),
(LlamaCompatConfig, "api_key", "LLAMA_API_KEY", {}),
(NVIDIAConfig, "api_key", "NVIDIA_API_KEY", {}),
(OllamaImplConfig, None, None, {}),
(OpenAIConfig, "api_key", "OPENAI_API_KEY", {}),
(RunpodImplConfig, "api_token", "RUNPOD_API_TOKEN", {}),
(SambaNovaImplConfig, "api_key", "SAMBANOVA_API_KEY", {}),
(TGIImplConfig, None, None, {"url": "FAKE"}),
(TogetherImplConfig, "api_key", "TOGETHER_API_KEY", {}),
(VertexAIConfig, None, None, {"project": "FAKE", "location": "FAKE"}),
(VLLMInferenceAdapterConfig, "api_token", "VLLM_API_TOKEN", {}),
(WatsonXConfig, "api_key", "WATSONX_API_KEY", {}),
],
)
def test_provider_config_auth_credentials(self, monkeypatch, config_cls, alias_name, env_name, extra_config):
"""Test that the config class correctly maps the alias to auth_credential."""
secret_value = config_cls.__name__
if alias_name is None:
pytest.skip("No alias name provided for this config class.")
config = config_cls(**{alias_name: secret_value, **extra_config})
assert config.auth_credential is not None
assert config.auth_credential.get_secret_value() == secret_value
schema = config_cls.model_json_schema()
assert alias_name in schema["properties"]
assert "auth_credential" not in schema["properties"]
if env_name:
monkeypatch.setenv(env_name, secret_value)
sample_config = config_cls.sample_run_config()
expanded_config = replace_env_vars(sample_config)
config_from_sample = config_cls(**{**expanded_config, **extra_config})
assert config_from_sample.auth_credential is not None
assert config_from_sample.auth_credential.get_secret_value() == secret_value
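# The pattern being verified, sketched with plain Pydantic (an assumption about
# how the provider configs are built, not a copy of them): the secret lives on
# `auth_credential` but is populated and serialized through a provider-specific
# alias such as `api_key`, so only the alias appears in the JSON schema.
def _alias_sketch():
    from pydantic import BaseModel, Field, SecretStr

    class _Cfg(BaseModel):
        auth_credential: SecretStr | None = Field(default=None, alias="api_key")

    cfg = _Cfg(api_key="sk-test")  # populated via the alias
    assert cfg.auth_credential.get_secret_value() == "sk-test"
    assert "api_key" in _Cfg.model_json_schema()["properties"]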

View file

@ -9,38 +9,29 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np
import pytest
from chromadb import PersistentClient
from pymilvus import MilvusClient, connections
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
from llama_stack.providers.inline.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.chroma.chroma import ChromaIndex, ChromaVectorIOAdapter, maybe_await
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
from llama_stack.providers.remote.vector_io.weaviate.config import WeaviateVectorIOConfig
from llama_stack.providers.remote.vector_io.weaviate.weaviate import WeaviateIndex, WeaviateVectorIOAdapter
from llama_stack.providers.utils.kvstore import register_kvstore_backends
EMBEDDING_DIMENSION = 384
EMBEDDING_DIMENSION = 768
COLLECTION_PREFIX = "test_collection"
MILVUS_ALIAS = "test_milvus"
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss", "chroma", "pgvector", "weaviate"])
@pytest.fixture(params=["sqlite_vec", "faiss", "pgvector"])
def vector_provider(request):
return request.param
@pytest.fixture
def vector_db_id() -> str:
def vector_store_id() -> str:
return f"test-vector-db-{random.randint(1, 100)}"
@ -122,8 +113,9 @@ async def unique_kvstore_config(tmp_path_factory):
unique_id = f"test_kv_{np.random.randint(1e6)}"
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / f"{unique_id}.db")
return SqliteKVStoreConfig(db_path=db_path)
backend_name = f"kv_vector_{unique_id}"
register_kvstore_backends({backend_name: SqliteKVStoreConfig(db_path=db_path)})
return KVStoreReference(backend=backend_name, namespace=f"vector_io::{unique_id}")
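# Shape of the storage indirection used above, assuming registration is global
# for the process: a backend config is registered once under a name, and
# consumers hold only a lightweight reference that is resolved at
# initialization time. Paths and names below are illustrative.
#
# register_kvstore_backends({"kv_default": SqliteKVStoreConfig(db_path="/tmp/kv.db")})
# ref = KVStoreReference(backend="kv_default", namespace="vector_io::example")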
@pytest.fixture(scope="session")
@ -148,7 +140,7 @@ async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
config = SQLiteVectorIOConfig(
db_path=sqlite_vec_db_path,
kvstore=unique_kvstore_config,
persistence=unique_kvstore_config,
)
adapter = SQLiteVecVectorIOAdapter(
config=config,
@ -157,8 +149,8 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
)
collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
await adapter.register_vector_store(
VectorStore(
identifier=collection_id,
provider_id="test_provider",
embedding_model="test_model",
@ -170,46 +162,6 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
await adapter.shutdown()
@pytest.fixture(scope="session")
def milvus_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_milvus.db")
return db_path
@pytest.fixture
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
client = MilvusClient(milvus_vec_db_path)
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
index = MilvusIndex(client, name, consistency_level="Strong")
index.db_path = milvus_vec_db_path
yield index
@pytest.fixture
async def milvus_vec_adapter(milvus_vec_db_path, unique_kvstore_config, mock_inference_api):
config = MilvusVectorIOConfig(
db_path=milvus_vec_db_path,
kvstore=unique_kvstore_config,
)
adapter = MilvusVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=adapter.metadata_collection_name,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=128,
)
)
yield adapter
await adapter.shutdown()
@pytest.fixture
def faiss_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db")
@ -226,7 +178,7 @@ async def faiss_vec_index(embedding_dimension):
@pytest.fixture
async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
config = FaissVectorIOConfig(
kvstore=unique_kvstore_config,
persistence=unique_kvstore_config,
)
adapter = FaissVectorIOAdapter(
config=config,
@ -234,8 +186,8 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
files_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
await adapter.register_vector_store(
VectorStore(
identifier=f"faiss_test_collection_{np.random.randint(1e6)}",
provider_id="test_provider",
embedding_model="test_model",
@ -246,98 +198,6 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
await adapter.shutdown()
@pytest.fixture
def chroma_vec_db_path(tmp_path_factory):
persist_dir = tmp_path_factory.mktemp(f"chroma_{np.random.randint(1e6)}")
return str(persist_dir)
@pytest.fixture
async def chroma_vec_index(chroma_vec_db_path, embedding_dimension):
client = PersistentClient(path=chroma_vec_db_path)
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
collection = await maybe_await(client.get_or_create_collection(name))
index = ChromaIndex(client=client, collection=collection)
await index.initialize()
yield index
await index.delete()
@pytest.fixture
async def chroma_vec_adapter(chroma_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
config = ChromaVectorIOConfig(
db_path=chroma_vec_db_path,
kvstore=unique_kvstore_config,
)
adapter = ChromaVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=f"chroma_test_collection_{random.randint(1, 1_000_000)}",
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
yield adapter
await adapter.shutdown()
@pytest.fixture
def qdrant_vec_db_path(tmp_path_factory):
import uuid
db_path = str(tmp_path_factory.getbasetemp() / f"test_qdrant_{uuid.uuid4()}.db")
return db_path
@pytest.fixture
async def qdrant_vec_adapter(qdrant_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
import uuid
config = QdrantVectorIOConfig(
db_path=qdrant_vec_db_path,
kvstore=unique_kvstore_config,
)
adapter = QdrantVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
collection_id = f"qdrant_test_collection_{uuid.uuid4()}"
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=collection_id,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
adapter.test_collection_id = collection_id
yield adapter
await adapter.shutdown()
@pytest.fixture
async def qdrant_vec_index(qdrant_vec_db_path, embedding_dimension):
import uuid
from qdrant_client import AsyncQdrantClient
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantIndex
client = AsyncQdrantClient(path=qdrant_vec_db_path)
collection_name = f"qdrant_test_collection_{uuid.uuid4()}"
index = QdrantIndex(client, collection_name)
yield index
await index.delete()
@pytest.fixture
def mock_psycopg2_connection():
connection = MagicMock()
@ -355,7 +215,7 @@ def mock_psycopg2_connection():
async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
vector_store = VectorStore(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
@ -365,7 +225,7 @@ async def pgvector_vec_index(embedding_dimension, mock_psycopg2_connection):
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.execute_values"):
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
index = PGVectorIndex(vector_store, embedding_dimension, connection, distance_metric="COSINE")
index._test_chunks = []
original_add_chunks = index.add_chunks
@ -393,7 +253,7 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
db="test_db",
user="test_user",
password="test_password",
kvstore=unique_kvstore_config,
persistence=unique_kvstore_config,
)
adapter = PGVectorVectorIOAdapter(config, mock_inference_api, None)
@ -421,110 +281,41 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
await adapter.initialize()
adapter.conn = mock_conn
async def mock_insert_chunks(vector_db_id, chunks, ttl_seconds=None):
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
async def mock_insert_chunks(vector_store_id, chunks, ttl_seconds=None):
index = await adapter._get_and_cache_vector_store_index(vector_store_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
raise ValueError(f"Vector DB {vector_store_id} not found")
await index.insert_chunks(chunks)
adapter.insert_chunks = mock_insert_chunks
async def mock_query_chunks(vector_db_id, query, params=None):
index = await adapter._get_and_cache_vector_db_index(vector_db_id)
async def mock_query_chunks(vector_store_id, query, params=None):
index = await adapter._get_and_cache_vector_store_index(vector_store_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
raise ValueError(f"Vector DB {vector_store_id} not found")
return await index.query_chunks(query, params)
adapter.query_chunks = mock_query_chunks
test_vector_db = VectorDB(
test_vector_store = VectorStore(
identifier=f"pgvector_test_collection_{random.randint(1, 1_000_000)}",
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
await adapter.register_vector_db(test_vector_db)
adapter.test_collection_id = test_vector_db.identifier
await adapter.register_vector_store(test_vector_store)
adapter.test_collection_id = test_vector_store.identifier
yield adapter
await adapter.shutdown()
@pytest.fixture(scope="session")
def weaviate_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_weaviate.db")
return db_path
@pytest.fixture
async def weaviate_vec_index(weaviate_vec_db_path):
import pytest_socket
import weaviate
pytest_socket.enable_socket()
client = weaviate.connect_to_embedded(
hostname="localhost",
port=8080,
grpc_port=50051,
persistence_data_path=weaviate_vec_db_path,
)
index = WeaviateIndex(client=client, collection_name="Testcollection")
await index.initialize()
yield index
await index.delete()
client.close()
@pytest.fixture
async def weaviate_vec_adapter(weaviate_vec_db_path, unique_kvstore_config, mock_inference_api, embedding_dimension):
import pytest_socket
import weaviate
pytest_socket.enable_socket()
client = weaviate.connect_to_embedded(
hostname="localhost",
port=8080,
grpc_port=50051,
persistence_data_path=weaviate_vec_db_path,
)
config = WeaviateVectorIOConfig(
weaviate_cluster_url="localhost:8080",
weaviate_api_key=None,
kvstore=unique_kvstore_config,
)
adapter = WeaviateVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
collection_id = f"weaviate_test_collection_{random.randint(1, 1_000_000)}"
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=collection_id,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
adapter.test_collection_id = collection_id
yield adapter
await adapter.shutdown()
client.close()
@pytest.fixture
def vector_io_adapter(vector_provider, request):
vector_provider_dict = {
"milvus": "milvus_vec_adapter",
"faiss": "faiss_vec_adapter",
"sqlite_vec": "sqlite_vec_adapter",
"chroma": "chroma_vec_adapter",
"qdrant": "qdrant_vec_adapter",
"pgvector": "pgvector_vec_adapter",
"weaviate": "weaviate_vec_adapter",
}
return request.getfixturevalue(vector_provider_dict[vector_provider])

View file

@ -1,326 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from llama_stack.apis.vector_io import QueryChunksResponse
# Mock the entire pymilvus module
pymilvus_mock = MagicMock()
pymilvus_mock.DataType = MagicMock()
pymilvus_mock.MilvusClient = MagicMock
pymilvus_mock.RRFRanker = MagicMock
pymilvus_mock.WeightedRanker = MagicMock
pymilvus_mock.AnnSearchRequest = MagicMock
# Apply the mock before importing MilvusIndex
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_milvus.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
MILVUS_PROVIDER = "milvus"
@pytest.fixture
async def mock_milvus_client() -> MagicMock:
"""Create a mock Milvus client with common method behaviors."""
client = MagicMock()
# Mock collection operations
client.has_collection.return_value = False # Initially no collection
client.create_collection.return_value = None
client.drop_collection.return_value = None
# Mock insert operation
client.insert.return_value = {"insert_count": 10}
# Mock search operation - return mock results (data should be dict, not JSON string)
client.search.return_value = [
[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
},
{
"id": 1,
"distance": 0.2,
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
},
]
]
# Mock query operation for keyword search (data should be dict, not JSON string)
client.query.return_value = [
{
"chunk_id": "chunk1",
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
"score": 0.9,
},
{
"chunk_id": "chunk2",
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
"score": 0.8,
},
{
"chunk_id": "chunk3",
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
"score": 0.7,
},
]
return client
@pytest.fixture
async def milvus_index(mock_milvus_client):
"""Create a MilvusIndex with mocked client."""
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
yield index
# No real cleanup needed since we're using mocks
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
# Setup: collection doesn't exist initially, then exists after creation
mock_milvus_client.has_collection.side_effect = [False, True]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Verify collection was created and data was inserted
mock_milvus_client.create_collection.assert_called_once()
mock_milvus_client.insert.assert_called_once()
# Verify the insert call had the right number of chunks
insert_call = mock_milvus_client.insert.call_args
assert len(insert_call[1]["data"]) == len(sample_chunks)
async def test_query_chunks_vector(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
# Setup: Add chunks first
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test vector search
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
mock_milvus_client.search.assert_called_once()
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test keyword search
query_string = "Sentence 5"
response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
"""Test that when BM25 search fails, the system falls back to simple text search."""
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Force BM25 search to fail
mock_milvus_client.search.side_effect = Exception("BM25 search not available")
# Mock simple text search results
mock_milvus_client.query.return_value = [
{
"chunk_id": "chunk1",
"chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
},
{
"chunk_id": "chunk2",
"chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
},
]
# Test keyword search that should fall back to simple text search
query_string = "Python"
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
# Verify response structure
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) > 0, "Fallback search should return results"
# Verify that simple text search was used (query method called instead of search)
mock_milvus_client.query.assert_called_once()
mock_milvus_client.search.assert_called_once() # Called once but failed
# Verify the query uses parameterized filter with filter_params
query_call_args = mock_milvus_client.query.call_args
assert "filter" in query_call_args[1], "Query should include filter for text search"
assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"
# Verify all returned chunks have score 1.0 (simple binary scoring)
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
async def test_delete_collection(milvus_index, mock_milvus_client):
# Test collection deletion
mock_milvus_client.has_collection.return_value = True
await milvus_index.delete()
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
async def test_query_hybrid_search_rrf(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
"""Test hybrid search with RRF reranker."""
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results
mock_milvus_client.hybrid_search.return_value = [
[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
},
{
"id": 1,
"distance": 0.2,
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
},
]
]
# Test hybrid search with RRF reranker
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
query_string = "test query"
response = await milvus_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=2,
score_threshold=0.0,
reranker_type="rrf",
reranker_params={"impact_factor": 60.0},
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
assert len(response.scores) == 2
# Verify hybrid search was called with correct parameters
mock_milvus_client.hybrid_search.assert_called_once()
call_args = mock_milvus_client.hybrid_search.call_args
# Check that the request contains both vector and BM25 search requests
reqs = call_args[1]["reqs"]
assert len(reqs) == 2
assert reqs[0].anns_field == "vector"
assert reqs[1].anns_field == "sparse"
ranker = call_args[1]["ranker"]
assert ranker is not None
async def test_query_hybrid_search_weighted(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
"""Test hybrid search with weighted reranker."""
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results
mock_milvus_client.hybrid_search.return_value = [
[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
},
{
"id": 1,
"distance": 0.2,
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
},
]
]
# Test hybrid search with weighted reranker
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
query_string = "test query"
response = await milvus_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=2,
score_threshold=0.0,
reranker_type="weighted",
reranker_params={"alpha": 0.7},
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
assert len(response.scores) == 2
# Verify hybrid search was called with correct parameters
mock_milvus_client.hybrid_search.assert_called_once()
call_args = mock_milvus_client.hybrid_search.call_args
ranker = call_args[1]["ranker"]
assert ranker is not None
async def test_query_hybrid_search_default_rrf(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
"""Test hybrid search with default RRF reranker (no reranker_type specified)."""
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Mock hybrid search results
mock_milvus_client.hybrid_search.return_value = [
[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
},
]
]
# Test hybrid search with default reranker (should be RRF)
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
query_string = "test query"
response = await milvus_index.query_hybrid(
embedding=query_embedding,
query_string=query_string,
k=1,
score_threshold=0.0,
reranker_type="unknown_type", # Should default to RRF
reranker_params=None, # Should use default impact_factor
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 1
# Verify hybrid search was called with RRF reranker
mock_milvus_client.hybrid_search.assert_called_once()
call_args = mock_milvus_client.hybrid_search.call_args
ranker = call_args[1]["ranker"]
assert ranker is not None

View file

@ -1,138 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
from unittest.mock import patch
import pytest
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex
PGVECTOR_PROVIDER = "pgvector"
@pytest.fixture(scope="session")
def loop():
return asyncio.new_event_loop()
@pytest.fixture
def embedding_dimension():
"""Default embedding dimension for tests."""
return 384
@pytest.fixture
async def pgvector_index(embedding_dimension, mock_psycopg2_connection):
"""Create a PGVectorIndex instance with mocked database connection."""
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
provider_id=PGVECTOR_PROVIDER,
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
)
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
# Use explicit COSINE distance metric for consistent testing
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
return index, cursor
class TestPGVectorIndex:
def test_distance_metric_validation(self, embedding_dimension, mock_psycopg2_connection):
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
provider_id=PGVECTOR_PROVIDER,
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
)
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="L2")
assert index.distance_metric == "L2"
with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID")
def test_get_pgvector_search_function(self, pgvector_index):
index, cursor = pgvector_index
supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION
for metric, function in supported_metrics.items():
index.distance_metric = metric
assert index.get_pgvector_search_function() == function
def test_check_distance_metric_availability(self, pgvector_index):
index, cursor = pgvector_index
supported_metrics = index.PGVECTOR_DISTANCE_METRIC_TO_SEARCH_FUNCTION
for metric in supported_metrics:
index.check_distance_metric_availability(metric)
with pytest.raises(ValueError, match="Distance metric 'INVALID' is not supported"):
index.check_distance_metric_availability("INVALID")
def test_constructor_invalid_distance_metric(self, embedding_dimension, mock_psycopg2_connection):
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
provider_id=PGVECTOR_PROVIDER,
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
)
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
with pytest.raises(ValueError, match="Distance metric 'INVALID_METRIC' is not supported by PGVector"):
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="INVALID_METRIC")
with pytest.raises(ValueError, match="Supported metrics are:"):
PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="UNKNOWN")
try:
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric="COSINE")
assert index.distance_metric == "COSINE"
except ValueError:
pytest.fail("Valid distance metric 'COSINE' should not raise ValueError")
def test_constructor_all_supported_distance_metrics(self, embedding_dimension, mock_psycopg2_connection):
connection, cursor = mock_psycopg2_connection
vector_db = VectorDB(
identifier="test-vector-db",
embedding_model="test-model",
embedding_dimension=embedding_dimension,
provider_id=PGVECTOR_PROVIDER,
provider_resource_id=f"{PGVECTOR_PROVIDER}:test-vector-db",
)
supported_metrics = ["L2", "L1", "COSINE", "INNER_PRODUCT", "HAMMING", "JACCARD"]
with patch("llama_stack.providers.remote.vector_io.pgvector.pgvector.psycopg2"):
for metric in supported_metrics:
try:
index = PGVectorIndex(vector_db, embedding_dimension, connection, distance_metric=metric)
assert index.distance_metric == metric
expected_operators = {
"L2": "<->",
"L1": "<+>",
"COSINE": "<=>",
"INNER_PRODUCT": "<#>",
"HAMMING": "<~>",
"JACCARD": "<%>",
}
assert index.get_pgvector_search_function() == expected_operators[metric]
except Exception as e:
pytest.fail(f"Valid distance metric '{metric}' should not raise exception: {e}")

View file

@ -11,8 +11,8 @@ import numpy as np
import pytest
from llama_stack.apis.files import Files
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.providers.datatypes import HealthStatus
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import (
@ -39,12 +39,12 @@ def loop():
@pytest.fixture
def embedding_dimension():
return 384
return 768
@pytest.fixture
def vector_db_id():
return "test_vector_db"
def vector_store_id():
return "test_vector_store"
@pytest.fixture
@ -61,12 +61,12 @@ def sample_embeddings(embedding_dimension):
@pytest.fixture
def mock_vector_db(vector_db_id, embedding_dimension) -> MagicMock:
mock_vector_db = MagicMock(spec=VectorDB)
mock_vector_db.embedding_model = "mock_embedding_model"
mock_vector_db.identifier = vector_db_id
mock_vector_db.embedding_dimension = embedding_dimension
return mock_vector_db
def mock_vector_store(vector_store_id, embedding_dimension) -> MagicMock:
mock_vector_store = MagicMock(spec=VectorStore)
mock_vector_store.embedding_model = "mock_embedding_model"
mock_vector_store.identifier = vector_store_id
mock_vector_store.embedding_dimension = embedding_dimension
return mock_vector_store
@pytest.fixture

View file

@ -1,147 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import os
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack.apis.inference import Inference
from llama_stack.apis.inference.inference import OpenAIEmbeddingData, OpenAIEmbeddingsResponse, OpenAIEmbeddingUsage
from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorDB,
VectorDBStore,
)
from llama_stack.providers.inline.vector_io.qdrant.config import (
QdrantVectorIOConfig as InlineQdrantVectorIOConfig,
)
from llama_stack.providers.remote.vector_io.qdrant.qdrant import (
QdrantVectorIOAdapter,
)
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
# This test is a unit test for the QdrantVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_qdrant.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
@pytest.fixture
def qdrant_config(tmp_path) -> InlineQdrantVectorIOConfig:
kvstore_config = SqliteKVStoreConfig(db_name=os.path.join(tmp_path, "test_kvstore.db"))
return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db"), kvstore=kvstore_config)
@pytest.fixture(scope="session")
def loop():
return asyncio.new_event_loop()
@pytest.fixture
def mock_vector_db(vector_db_id) -> MagicMock:
mock_vector_db = MagicMock(spec=VectorDB)
mock_vector_db.embedding_model = "embedding_model"
mock_vector_db.identifier = vector_db_id
mock_vector_db.embedding_dimension = 384
mock_vector_db.model_dump_json.return_value = (
'{"identifier": "'
+ vector_db_id
+ '", "provider_id": "qdrant", "embedding_model": "embedding_model", "embedding_dimension": 384}'
)
return mock_vector_db
@pytest.fixture
def mock_vector_db_store(mock_vector_db) -> MagicMock:
mock_store = MagicMock(spec=VectorDBStore)
mock_store.get_vector_db = AsyncMock(return_value=mock_vector_db)
return mock_store
@pytest.fixture
def mock_api_service(sample_embeddings):
mock_api_service = MagicMock(spec=Inference)
mock_api_service.openai_embeddings = AsyncMock(
return_value=OpenAIEmbeddingsResponse(
model="mock-embedding-model",
data=[OpenAIEmbeddingData(embedding=sample, index=i) for i, sample in enumerate(sample_embeddings)],
usage=OpenAIEmbeddingUsage(prompt_tokens=10, total_tokens=10),
)
)
return mock_api_service
@pytest.fixture
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service, files_api=None)
adapter.vector_db_store = mock_vector_db_store
await adapter.initialize()
yield adapter
await adapter.shutdown()
__QUERY = "Sample query"
@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 60)])
async def test_qdrant_adapter_returns_expected_chunks(
qdrant_adapter: QdrantVectorIOAdapter,
vector_db_id,
sample_chunks,
sample_embeddings,
max_query_chunks,
expected_chunks,
) -> None:
assert qdrant_adapter is not None
await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
index = await qdrant_adapter._get_and_cache_vector_db_index(vector_db_id=vector_db_id)
assert index is not None
response = await qdrant_adapter.query_chunks(
query=__QUERY,
vector_db_id=vector_db_id,
params={"max_chunks": max_query_chunks, "mode": "vector"},
)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == expected_chunks
# To bypass the attempt to convert a Mock to JSON
def _prepare_for_json(value: Any) -> str:
return str(value)
@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json)
async def test_qdrant_register_and_unregister_vector_db(
qdrant_adapter: QdrantVectorIOAdapter,
mock_vector_db,
sample_chunks,
) -> None:
# Initially, no collections
vector_db_id = mock_vector_db.identifier
assert len((await qdrant_adapter.client.get_collections()).collections) == 0
# Register does not create a collection
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
await qdrant_adapter.register_vector_db(mock_vector_db)
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
# First insert creates the collection
await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks)
assert await qdrant_adapter.client.collection_exists(vector_db_id)
# Unregister deletes the collection
await qdrant_adapter.unregister_vector_db(vector_db_id)
assert not (await qdrant_adapter.client.collection_exists(vector_db_id))
assert len((await qdrant_adapter.client.get_collections()).collections) == 0

View file

@ -12,14 +12,16 @@ import numpy as np
import pytest
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
OpenAICreateVectorStoreRequestWithExtraBody,
QueryChunksResponse,
VectorStoreChunkingStrategyAuto,
VectorStoreFileObject,
)
from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import VECTOR_DBS_PREFIX
# This test is a unit test for the inline VectorIO providers. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
@ -69,7 +71,7 @@ async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimensio
async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
key = f"{VECTOR_DBS_PREFIX}db1"
dummy = VectorDB(
dummy = VectorStore(
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
)
await vector_io_adapter.kvstore.set(key=key, value=json.dumps(dummy.model_dump()))
@ -79,10 +81,10 @@ async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
async def test_persistence_across_adapter_restarts(vector_io_adapter):
await vector_io_adapter.initialize()
dummy = VectorDB(
dummy = VectorStore(
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
)
await vector_io_adapter.register_vector_db(dummy)
await vector_io_adapter.register_vector_store(dummy)
await vector_io_adapter.shutdown()
await vector_io_adapter.initialize()
@ -90,26 +92,22 @@ async def test_persistence_across_adapter_restarts(vector_io_adapter):
await vector_io_adapter.shutdown()
async def test_register_and_unregister_vector_db(vector_io_adapter):
async def test_register_and_unregister_vector_store(vector_io_adapter):
unique_id = f"foo_db_{np.random.randint(1e6)}"
dummy = VectorDB(
dummy = VectorStore(
identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
)
await vector_io_adapter.register_vector_db(dummy)
await vector_io_adapter.register_vector_store(dummy)
assert dummy.identifier in vector_io_adapter.cache
await vector_io_adapter.unregister_vector_db(dummy.identifier)
await vector_io_adapter.unregister_vector_store(dummy.identifier)
assert dummy.identifier not in vector_io_adapter.cache
async def test_query_unregistered_raises(vector_io_adapter, vector_provider):
fake_emb = np.zeros(8, dtype=np.float32)
if vector_provider == "chroma":
with pytest.raises(AttributeError):
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
else:
with pytest.raises(ValueError):
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
with pytest.raises(ValueError):
await vector_io_adapter.query_chunks("no_such_db", fake_emb)
async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
@ -123,12 +121,43 @@ async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
async def test_insert_chunks_missing_db_raises(vector_io_adapter):
vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)
vector_io_adapter._get_and_cache_vector_store_index = AsyncMock(return_value=None)
with pytest.raises(ValueError):
await vector_io_adapter.insert_chunks("db_not_exist", [])
async def test_insert_chunks_with_missing_document_id(vector_io_adapter):
"""Ensure no KeyError when document_id is missing or in different places."""
from llama_stack.apis.vector_io import Chunk, ChunkMetadata
fake_index = AsyncMock()
vector_io_adapter.cache["db1"] = fake_index
# Various document_id scenarios that shouldn't crash
chunks = [
Chunk(content="has doc_id in metadata", metadata={"document_id": "doc-1"}),
Chunk(content="no doc_id anywhere", metadata={"source": "test"}),
Chunk(content="doc_id in chunk_metadata", chunk_metadata=ChunkMetadata(document_id="doc-3")),
]
# Should work without KeyError
await vector_io_adapter.insert_chunks("db1", chunks)
fake_index.insert_chunks.assert_awaited_once()
async def test_document_id_with_invalid_type_raises_error():
"""Ensure TypeError is raised when document_id is not a string."""
from llama_stack.apis.vector_io import Chunk
# Integer document_id should raise TypeError
chunk = Chunk(content="test", metadata={"document_id": 12345})
with pytest.raises(TypeError) as exc_info:
_ = chunk.document_id
assert "metadata['document_id'] must be a string" in str(exc_info.value)
assert "got int" in str(exc_info.value)
async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter):
expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
@ -141,7 +170,7 @@ async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter
async def test_query_chunks_missing_db_raises(vector_io_adapter):
vector_io_adapter._get_and_cache_vector_store_index = AsyncMock(return_value=None)
with pytest.raises(ValueError):
await vector_io_adapter.query_chunks("db_missing", "q", None)
@ -153,7 +182,7 @@ async def test_save_openai_vector_store(vector_io_adapter):
"id": store_id,
"name": "Test Store",
"description": "A test OpenAI vector store",
"vector_db_id": "test_db",
"vector_store_id": "test_db",
"embedding_model": "test_model",
}
@ -169,7 +198,7 @@ async def test_update_openai_vector_store(vector_io_adapter):
"id": store_id,
"name": "Test Store",
"description": "A test OpenAI vector store",
"vector_db_id": "test_db",
"vector_store_id": "test_db",
"embedding_model": "test_model",
}
@ -185,7 +214,7 @@ async def test_delete_openai_vector_store(vector_io_adapter):
"id": store_id,
"name": "Test Store",
"description": "A test OpenAI vector store",
"vector_db_id": "test_db",
"vector_store_id": "test_db",
"embedding_model": "test_model",
}
@ -200,7 +229,7 @@ async def test_load_openai_vector_stores(vector_io_adapter):
"id": store_id,
"name": "Test Store",
"description": "A test OpenAI vector store",
"vector_db_id": "test_db",
"vector_store_id": "test_db",
"embedding_model": "test_model",
}
@ -330,8 +359,7 @@ async def test_create_vector_store_file_batch(vector_io_adapter):
vector_io_adapter._process_file_batch_async = AsyncMock()
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
assert batch.vector_store_id == store_id
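# The calls above reflect a signature change: file_ids and other body fields
# now travel in a single request object. A minimal sketch, assuming a pydantic
# model that tolerates extra fields (the "...WithExtraBody" suffix suggests
# extra="allow"); the exact schema is not shown in this diff.
from pydantic import BaseModel, ConfigDict

class _FileBatchRequestSketch(BaseModel):  # hypothetical stand-in
    model_config = ConfigDict(extra="allow")
    file_ids: list[str]

req = _FileBatchRequestSketch(file_ids=["file_1"], chunking_strategy="auto")  # extra field is kept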
@ -358,8 +386,7 @@ async def test_retrieve_vector_store_file_batch(vector_io_adapter):
# Create batch first
created_batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# Retrieve batch
@ -392,8 +419,7 @@ async def test_cancel_vector_store_file_batch(vector_io_adapter):
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# Cancel batch
@ -438,8 +464,7 @@ async def test_list_files_in_vector_store_file_batch(vector_io_adapter):
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# List files
@ -459,7 +484,7 @@ async def test_file_batch_validation_errors(vector_io_adapter):
with pytest.raises(VectorStoreNotFoundError):
await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id="nonexistent",
file_ids=["file_1"],
params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_1"]),
)
# Setup store for remaining tests
@ -476,8 +501,7 @@ async def test_file_batch_validation_errors(vector_io_adapter):
# Test wrong vector store for batch
vector_io_adapter.openai_attach_file_to_vector_store = AsyncMock()
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_1"])
)
# Create wrong_store so it exists but the batch doesn't belong to it
@ -524,8 +548,7 @@ async def test_file_batch_pagination(vector_io_adapter):
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# Test pagination with limit
@ -597,8 +620,7 @@ async def test_file_batch_status_filtering(vector_io_adapter):
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# Test filtering by completed status
@ -640,8 +662,7 @@ async def test_cancel_completed_batch_fails(vector_io_adapter):
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# Manually update status to completed
@ -675,8 +696,7 @@ async def test_file_batch_persistence_across_restarts(vector_io_adapter):
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
batch_id = batch.id
@ -731,8 +751,7 @@ async def test_cancelled_batch_persists_in_storage(vector_io_adapter):
# Create batch
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
batch_id = batch.id
@ -779,10 +798,10 @@ async def test_only_in_progress_batches_resumed(vector_io_adapter):
# Create multiple batches
batch1 = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, file_ids=["file_1"]
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_1"])
)
batch2 = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, file_ids=["file_2"]
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_2"])
)
# Complete one batch (should persist with completed status)
@ -795,7 +814,7 @@ async def test_only_in_progress_batches_resumed(vector_io_adapter):
# Create a third batch that stays in progress
batch3 = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, file_ids=["file_3"]
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["file_3"])
)
# Simulate restart - clear memory and reload from persistence
@ -956,8 +975,7 @@ async def test_max_concurrent_files_per_batch(vector_io_adapter):
file_ids = [f"file_{i}" for i in range(8)] # 8 files, but limit should be 5
batch = await vector_io_adapter.openai_create_vector_store_file_batch(
vector_store_id=store_id, params=OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=file_ids)
)
# Give time for the semaphore logic to start processing files
@ -975,3 +993,130 @@ async def test_max_concurrent_files_per_batch(vector_io_adapter):
assert batch.status == "in_progress"
assert batch.file_counts.total == 8
assert batch.file_counts.in_progress == 8
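# A minimal sketch (not the provider's actual code) of the concurrency cap the
# test above exercises: an asyncio.Semaphore bounds how many of a batch's
# files are processed at once. The limit of 5 is inferred from the test; the
# helper names are assumptions.
import asyncio

MAX_CONCURRENT_FILES_PER_BATCH = 5  # assumed limit

async def process_batch_files(file_ids, process_one):
    semaphore = asyncio.Semaphore(MAX_CONCURRENT_FILES_PER_BATCH)

    async def bounded(file_id):
        async with semaphore:  # at most 5 files in flight; the rest stay in_progress
            await process_one(file_id)

    await asyncio.gather(*(bounded(fid) for fid in file_ids))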
async def test_embedding_config_from_metadata(vector_io_adapter):
"""Test that embedding configuration is correctly extracted from metadata."""
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"
# Test with embedding config in metadata
params = OpenAICreateVectorStoreRequestWithExtraBody(
name="test_store",
metadata={
"embedding_model": "test-embedding-model",
"embedding_dimension": "512",
},
model_extra={},
)
await vector_io_adapter.openai_create_vector_store(params)
# Verify VectorStore was registered with correct embedding config from metadata
vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "test-embedding-model"
assert call_args.embedding_dimension == 512
async def test_embedding_config_from_extra_body(vector_io_adapter):
"""Test that embedding configuration is correctly extracted from extra_body when metadata is empty."""
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"
# Test with embedding config in extra_body only (metadata has no embedding_model)
params = OpenAICreateVectorStoreRequestWithExtraBody(
name="test_store",
metadata={}, # Empty metadata to ensure extra_body is used
**{
"embedding_model": "extra-body-model",
"embedding_dimension": 1024,
},
)
await vector_io_adapter.openai_create_vector_store(params)
# Verify VectorStore was registered with correct embedding config from extra_body
vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "extra-body-model"
assert call_args.embedding_dimension == 1024
async def test_embedding_config_consistency_check_passes(vector_io_adapter):
"""Test that consistent embedding config in both metadata and extra_body passes validation."""
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"
# Test with consistent embedding config in both metadata and extra_body
params = OpenAICreateVectorStoreRequestWithExtraBody(
name="test_store",
metadata={
"embedding_model": "consistent-model",
"embedding_dimension": "768",
},
**{
"embedding_model": "consistent-model",
"embedding_dimension": 768,
},
)
await vector_io_adapter.openai_create_vector_store(params)
# Should not raise any error and use metadata config
vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "consistent-model"
assert call_args.embedding_dimension == 768
async def test_embedding_config_defaults_when_missing(vector_io_adapter):
"""Test that embedding dimension defaults to 768 when not provided."""
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"
# Test with only embedding model, no dimension (metadata empty to use extra_body)
params = OpenAICreateVectorStoreRequestWithExtraBody(
name="test_store",
metadata={}, # Empty metadata to ensure extra_body is used
**{
"embedding_model": "model-without-dimension",
},
)
await vector_io_adapter.openai_create_vector_store(params)
# Should default to 768 dimensions
vector_io_adapter.register_vector_store.assert_called_once()
call_args = vector_io_adapter.register_vector_store.call_args[0][0]
assert call_args.embedding_model == "model-without-dimension"
assert call_args.embedding_dimension == 768
async def test_embedding_config_required_model_missing(vector_io_adapter):
"""Test that missing embedding model raises error."""
# Mock register_vector_store to avoid actual registration
vector_io_adapter.register_vector_store = AsyncMock()
# Set provider_id attribute for the adapter
vector_io_adapter.__provider_id__ = "test_provider"
# Mock the default model lookup to return None (no default model available)
vector_io_adapter._get_default_embedding_model_and_dimension = AsyncMock(return_value=None)
# Test with no embedding model provided
params = OpenAICreateVectorStoreRequestWithExtraBody(name="test_store", metadata={})
with pytest.raises(ValueError, match="embedding_model is required"):
await vector_io_adapter.openai_create_vector_store(params)
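# A minimal sketch, assuming the resolution order the five tests above pin
# down: metadata wins, extra_body is the fallback, the dimension defaults to
# 768, and a missing model (with no default available) raises ValueError.
# Function and parameter names are illustrative, not the adapter's API.
def resolve_embedding_config(metadata: dict, extra_body: dict) -> tuple[str, int]:
    model = metadata.get("embedding_model") or extra_body.get("embedding_model")
    if model is None:
        raise ValueError("embedding_model is required")
    dim = metadata.get("embedding_dimension") or extra_body.get("embedding_dimension") or 768
    return model, int(dim)  # metadata may carry the dimension as a string, e.g. "512"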