Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-09 13:14:39 +00:00)

Commit c66ebae9b6: Merge branch 'main' into chroma
207 changed files with 15490 additions and 7927 deletions
tests/unit/providers/agent/test_agent_meta_reference.py (new file, 347 lines)
@@ -0,0 +1,347 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents import Session
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.providers.inline.agents.meta_reference.persistence import (
|
||||
AgentPersistence,
|
||||
AgentSessionInfo,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_kvstore():
|
||||
return AsyncMock(spec=KVStore)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_policy():
|
||||
return []
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_persistence(mock_kvstore, mock_policy):
|
||||
return AgentPersistence(agent_id="test-agent-123", kvstore=mock_kvstore, policy=mock_policy)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_session():
|
||||
return AgentSessionInfo(
|
||||
session_id="session-123",
|
||||
session_name="Test Session",
|
||||
started_at=datetime.now(UTC),
|
||||
owner=User(principal="user-123", attributes=None),
|
||||
turns=[],
|
||||
identifier="test-session",
|
||||
type="session",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_session_json(sample_session):
|
||||
return sample_session.model_dump_json()
|
||||
|
||||
|
||||
class TestAgentPersistenceListSessions:
|
||||
def setup_mock_kvstore(self, mock_kvstore, session_keys=None, turn_keys=None, invalid_keys=None, custom_data=None):
|
||||
"""Helper to setup mock kvstore with sessions, turns, and custom/invalid data
|
||||
|
||||
Args:
|
||||
mock_kvstore: The mock KVStore object
|
||||
session_keys: List of session keys or dict mapping keys to custom session data
|
||||
turn_keys: List of turn keys or dict mapping keys to custom turn data
|
||||
invalid_keys: Dict mapping keys to invalid/corrupt data
|
||||
custom_data: Additional custom data to add to the mock responses
|
||||
"""
|
||||
all_keys = []
|
||||
mock_data = {}
|
||||
|
||||
# session keys
|
||||
if session_keys:
|
||||
if isinstance(session_keys, dict):
|
||||
all_keys.extend(session_keys.keys())
|
||||
mock_data.update({k: json.dumps(v) if isinstance(v, dict) else v for k, v in session_keys.items()})
|
||||
else:
|
||||
all_keys.extend(session_keys)
|
||||
for key in session_keys:
|
||||
session_id = key.split(":")[-1]
|
||||
mock_data[key] = json.dumps(
|
||||
{
|
||||
"session_id": session_id,
|
||||
"session_name": f"Session {session_id}",
|
||||
"started_at": datetime.now(UTC).isoformat(),
|
||||
"turns": [],
|
||||
}
|
||||
)
|
||||
|
||||
# turn keys
|
||||
if turn_keys:
|
||||
if isinstance(turn_keys, dict):
|
||||
all_keys.extend(turn_keys.keys())
|
||||
mock_data.update({k: json.dumps(v) if isinstance(v, dict) else v for k, v in turn_keys.items()})
|
||||
else:
|
||||
all_keys.extend(turn_keys)
|
||||
for key in turn_keys:
|
||||
parts = key.split(":")
|
||||
session_id = parts[-2]
|
||||
turn_id = parts[-1]
|
||||
mock_data[key] = json.dumps(
|
||||
{
|
||||
"turn_id": turn_id,
|
||||
"session_id": session_id,
|
||||
"input_messages": [],
|
||||
"started_at": datetime.now(UTC).isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
if invalid_keys:
|
||||
all_keys.extend(invalid_keys.keys())
|
||||
mock_data.update(invalid_keys)
|
||||
|
||||
if custom_data:
|
||||
mock_data.update(custom_data)
|
||||
|
||||
values_list = list(mock_data.values())
|
||||
mock_kvstore.values_in_range.return_value = values_list
|
||||
|
||||
async def mock_get(key):
|
||||
return mock_data.get(key)
|
||||
|
||||
mock_kvstore.get.side_effect = mock_get
|
||||
|
||||
return mock_data
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"scenario",
|
||||
[
|
||||
{
|
||||
# from this issue: https://github.com/meta-llama/llama-stack/issues/3048
|
||||
"name": "reported_bug",
|
||||
"session_keys": ["session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"],
|
||||
"turn_keys": [
|
||||
"session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d:eb7e818f-41fb-49a0-bdd6-464974a2d2ad"
|
||||
],
|
||||
"expected_sessions": ["1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"],
|
||||
},
|
||||
{
|
||||
"name": "basic_filtering",
|
||||
"session_keys": ["session:test-agent-123:session-1", "session:test-agent-123:session-2"],
|
||||
"turn_keys": ["session:test-agent-123:session-1:turn-1", "session:test-agent-123:session-1:turn-2"],
|
||||
"expected_sessions": ["session-1", "session-2"],
|
||||
},
|
||||
{
|
||||
"name": "multiple_turns_per_session",
|
||||
"session_keys": ["session:test-agent-123:session-456"],
|
||||
"turn_keys": [
|
||||
"session:test-agent-123:session-456:turn-789",
|
||||
"session:test-agent-123:session-456:turn-790",
|
||||
],
|
||||
"expected_sessions": ["session-456"],
|
||||
},
|
||||
{
|
||||
"name": "multiple_sessions_with_turns",
|
||||
"session_keys": ["session:test-agent-123:session-1", "session:test-agent-123:session-2"],
|
||||
"turn_keys": [
|
||||
"session:test-agent-123:session-1:turn-1",
|
||||
"session:test-agent-123:session-1:turn-2",
|
||||
"session:test-agent-123:session-2:turn-3",
|
||||
],
|
||||
"expected_sessions": ["session-1", "session-2"],
|
||||
},
|
||||
],
|
||||
)
|
||||
async def test_list_sessions_key_filtering(self, agent_persistence, mock_kvstore, scenario):
|
||||
self.setup_mock_kvstore(mock_kvstore, session_keys=scenario["session_keys"], turn_keys=scenario["turn_keys"])
|
||||
|
||||
with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
|
||||
result = await agent_persistence.list_sessions()
|
||||
|
||||
assert len(result) == len(scenario["expected_sessions"])
|
||||
session_ids = {s.session_id for s in result}
|
||||
for expected_id in scenario["expected_sessions"]:
|
||||
assert expected_id in session_ids
|
||||
|
||||
# no errors should be logged
|
||||
mock_log.error.assert_not_called()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"error_scenario",
|
||||
[
|
||||
{
|
||||
"name": "invalid_json",
|
||||
"valid_keys": ["session:test-agent-123:valid-session"],
|
||||
"invalid_data": {"session:test-agent-123:invalid-json": "corrupted-json-data{"},
|
||||
"expected_valid_sessions": ["valid-session"],
|
||||
"expected_error_count": 1,
|
||||
},
|
||||
{
|
||||
"name": "missing_fields",
|
||||
"valid_keys": ["session:test-agent-123:valid-session"],
|
||||
"invalid_data": {
|
||||
"session:test-agent-123:invalid-schema": json.dumps(
|
||||
{
|
||||
"session_id": "invalid-schema",
|
||||
"session_name": "Missing Fields",
|
||||
# missing `started_at` and `turns`
|
||||
}
|
||||
)
|
||||
},
|
||||
"expected_valid_sessions": ["valid-session"],
|
||||
"expected_error_count": 1,
|
||||
},
|
||||
{
|
||||
"name": "multiple_invalid",
|
||||
"valid_keys": ["session:test-agent-123:valid-session-1", "session:test-agent-123:valid-session-2"],
|
||||
"invalid_data": {
|
||||
"session:test-agent-123:corrupted-json": "not-valid-json{",
|
||||
"session:test-agent-123:incomplete-data": json.dumps({"incomplete": "data"}),
|
||||
},
|
||||
"expected_valid_sessions": ["valid-session-1", "valid-session-2"],
|
||||
"expected_error_count": 2,
|
||||
},
|
||||
],
|
||||
)
|
||||
async def test_list_sessions_error_handling(self, agent_persistence, mock_kvstore, error_scenario):
|
||||
session_keys = {}
|
||||
for key in error_scenario["valid_keys"]:
|
||||
session_id = key.split(":")[-1]
|
||||
session_keys[key] = {
|
||||
"session_id": session_id,
|
||||
"session_name": f"Valid {session_id}",
|
||||
"started_at": datetime.now(UTC).isoformat(),
|
||||
"turns": [],
|
||||
}
|
||||
|
||||
self.setup_mock_kvstore(mock_kvstore, session_keys=session_keys, invalid_keys=error_scenario["invalid_data"])
|
||||
|
||||
with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
|
||||
result = await agent_persistence.list_sessions()
|
||||
|
||||
# only valid sessions should be returned
|
||||
assert len(result) == len(error_scenario["expected_valid_sessions"])
|
||||
session_ids = {s.session_id for s in result}
|
||||
for expected_id in error_scenario["expected_valid_sessions"]:
|
||||
assert expected_id in session_ids
|
||||
|
||||
# error should be logged
|
||||
assert mock_log.error.call_count > 0
|
||||
assert mock_log.error.call_count == error_scenario["expected_error_count"]
|
||||
|
||||
async def test_list_sessions_empty(self, agent_persistence, mock_kvstore):
|
||||
mock_kvstore.values_in_range.return_value = []
|
||||
|
||||
result = await agent_persistence.list_sessions()
|
||||
|
||||
assert result == []
|
||||
mock_kvstore.values_in_range.assert_called_once_with(
|
||||
start_key="session:test-agent-123:", end_key="session:test-agent-123:\xff\xff\xff\xff"
|
||||
)
|
||||
|
||||
async def test_list_sessions_properties(self, agent_persistence, mock_kvstore):
|
||||
session_data = {
|
||||
"session_id": "session-123",
|
||||
"session_name": "Test Session",
|
||||
"started_at": datetime.now(UTC).isoformat(),
|
||||
"owner": {"principal": "user-123", "attributes": None},
|
||||
"turns": [],
|
||||
}
|
||||
|
||||
self.setup_mock_kvstore(mock_kvstore, session_keys={"session:test-agent-123:session-123": session_data})
|
||||
|
||||
result = await agent_persistence.list_sessions()
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], Session)
|
||||
assert result[0].session_id == "session-123"
|
||||
assert result[0].session_name == "Test Session"
|
||||
assert result[0].turns == []
|
||||
assert hasattr(result[0], "started_at")
|
||||
|
||||
async def test_list_sessions_kvstore_exception(self, agent_persistence, mock_kvstore):
|
||||
mock_kvstore.values_in_range.side_effect = Exception("KVStore error")
|
||||
|
||||
with pytest.raises(Exception, match="KVStore error"):
|
||||
await agent_persistence.list_sessions()
|
||||
|
||||
async def test_bug_data_loss_with_real_data(self, agent_persistence, mock_kvstore):
|
||||
# tests the handling of the issue reported in: https://github.com/meta-llama/llama-stack/issues/3048
|
||||
session_data = {
|
||||
"session_id": "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d",
|
||||
"session_name": "Test Session",
|
||||
"started_at": datetime.now(UTC).isoformat(),
|
||||
"turns": [],
|
||||
}
|
||||
|
||||
turn_data = {
|
||||
"turn_id": "eb7e818f-41fb-49a0-bdd6-464974a2d2ad",
|
||||
"session_id": "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d",
|
||||
"input_messages": [
|
||||
{"role": "user", "content": "if i had a cluster i would want to call it persistence01", "context": None}
|
||||
],
|
||||
"steps": [
|
||||
{
|
||||
"turn_id": "eb7e818f-41fb-49a0-bdd6-464974a2d2ad",
|
||||
"step_id": "c0f797dd-3d34-4bc5-a8f4-db6af9455132",
|
||||
"started_at": "2025-08-05T14:31:50.000484Z",
|
||||
"completed_at": "2025-08-05T14:31:51.303691Z",
|
||||
"step_type": "inference",
|
||||
"model_response": {
|
||||
"role": "assistant",
|
||||
"content": "OK, I can create a cluster named 'persistence01' for you.",
|
||||
"stop_reason": "end_of_turn",
|
||||
"tool_calls": [],
|
||||
},
|
||||
}
|
||||
],
|
||||
"output_message": {
|
||||
"role": "assistant",
|
||||
"content": "OK, I can create a cluster named 'persistence01' for you.",
|
||||
"stop_reason": "end_of_turn",
|
||||
"tool_calls": [],
|
||||
},
|
||||
"output_attachments": [],
|
||||
"started_at": "2025-08-05T14:31:49.999950Z",
|
||||
"completed_at": "2025-08-05T14:31:51.305384Z",
|
||||
}
|
||||
|
||||
mock_data = {
|
||||
"session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d": json.dumps(session_data),
|
||||
"session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d:eb7e818f-41fb-49a0-bdd6-464974a2d2ad": json.dumps(
|
||||
turn_data
|
||||
),
|
||||
}
|
||||
|
||||
mock_kvstore.values_in_range.return_value = list(mock_data.values())
|
||||
|
||||
async def mock_get(key):
|
||||
return mock_data.get(key)
|
||||
|
||||
mock_kvstore.get.side_effect = mock_get
|
||||
|
||||
with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
|
||||
result = await agent_persistence.list_sessions()
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].session_id == "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"
|
||||
|
||||
# confirm no errors logged
|
||||
mock_log.error.assert_not_called()
|
||||
|
||||
async def test_list_sessions_key_range_construction(self, agent_persistence, mock_kvstore):
|
||||
mock_kvstore.values_in_range.return_value = []
|
||||
|
||||
await agent_persistence.list_sessions()
|
||||
|
||||
mock_kvstore.values_in_range.assert_called_once_with(
|
||||
start_key="session:test-agent-123:", end_key="session:test-agent-123:\xff\xff\xff\xff"
|
||||
)
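The tests above pin down the observable contract of AgentPersistence.list_sessions(): it scans the "session:{agent_id}:" key range, ignores turn records stored under the same prefix, and logs and skips entries that fail to parse or validate (the data-loss scenario from issue #3048). The snippet below is a minimal illustrative sketch of that contract, not the actual persistence.py code; the helper name list_sessions_sketch and the "turn_id"-based filtering heuristic are assumptions made for illustration.

# Illustrative sketch only (not the real AgentPersistence.list_sessions); names are assumed.
import json
import logging

log = logging.getLogger(__name__)


async def list_sessions_sketch(kvstore, agent_id: str) -> list[dict]:
    # Scan every value stored under the session prefix for this agent.
    start_key = f"session:{agent_id}:"
    end_key = f"session:{agent_id}:\xff\xff\xff\xff"
    values = await kvstore.values_in_range(start_key=start_key, end_key=end_key)

    sessions: list[dict] = []
    for raw in values:
        try:
            data = json.loads(raw)
            if "turn_id" in data:
                # Turn records share the prefix; they must not show up as sessions.
                continue
            # The real provider validates this into an AgentSessionInfo model,
            # which is where the "missing fields" errors in the tests come from.
            sessions.append(data)
        except Exception as e:
            log.error(f"Error parsing session info: {e}")
    return sessions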
@@ -41,7 +41,7 @@ from llama_stack.apis.inference import (
)
|
||||
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
|
||||
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.providers.inline.agents.meta_reference.openai_responses import (
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
@@ -136,9 +136,12 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
input=input_text,
|
||||
model=model,
|
||||
temperature=0.1,
|
||||
stream=True, # Enable streaming to test content part events
|
||||
)
|
||||
|
||||
# Verify
|
||||
# For streaming response, collect all chunks
|
||||
chunks = [chunk async for chunk in result]
|
||||
|
||||
mock_inference_api.openai_chat_completion.assert_called_once_with(
|
||||
model=model,
|
||||
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
|
||||
|
@@ -147,11 +150,32 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
stream=True,
|
||||
temperature=0.1,
|
||||
)
|
||||
|
||||
# Should have content part events for text streaming
|
||||
# Expected: response.created, content_part.added, output_text.delta, content_part.done, response.completed
|
||||
assert len(chunks) >= 4
|
||||
assert chunks[0].type == "response.created"
|
||||
|
||||
# Check for content part events
|
||||
content_part_added_events = [c for c in chunks if c.type == "response.content_part.added"]
|
||||
content_part_done_events = [c for c in chunks if c.type == "response.content_part.done"]
|
||||
text_delta_events = [c for c in chunks if c.type == "response.output_text.delta"]
|
||||
|
||||
assert len(content_part_added_events) >= 1, "Should have content_part.added event for text"
|
||||
assert len(content_part_done_events) >= 1, "Should have content_part.done event for text"
|
||||
assert len(text_delta_events) >= 1, "Should have text delta events"
|
||||
|
||||
# Verify final event is completion
|
||||
assert chunks[-1].type == "response.completed"
|
||||
|
||||
# When streaming, the final response is in the last chunk
|
||||
final_response = chunks[-1].response
|
||||
assert final_response.model == model
|
||||
assert len(final_response.output) == 1
|
||||
assert isinstance(final_response.output[0], OpenAIResponseMessage)
|
||||
|
||||
openai_responses_impl.responses_store.store_response_object.assert_called_once()
|
||||
assert result.model == model
|
||||
assert len(result.output) == 1
|
||||
assert isinstance(result.output[0], OpenAIResponseMessage)
|
||||
assert result.output[0].content[0].text == "Dublin"
|
||||
assert final_response.output[0].content[0].text == "Dublin"
|
||||
|
||||
|
||||
async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
|
||||
|
@@ -272,7 +296,11 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
|
||||
# Check that we got the content from our mocked tool execution result
|
||||
chunks = [chunk async for chunk in result]
|
||||
assert len(chunks) == 2 # Should have response.created and response.completed
|
||||
|
||||
# Verify event types
|
||||
# Should have: response.created, output_item.added, function_call_arguments.delta,
|
||||
# function_call_arguments.done, output_item.done, response.completed
|
||||
assert len(chunks) == 6
|
||||
|
||||
# Verify inference API was called correctly (after iterating over result)
|
||||
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
|
||||
|
@@ -284,11 +312,17 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
assert chunks[0].type == "response.created"
|
||||
assert len(chunks[0].response.output) == 0
|
||||
|
||||
# Check streaming events
|
||||
assert chunks[1].type == "response.output_item.added"
|
||||
assert chunks[2].type == "response.function_call_arguments.delta"
|
||||
assert chunks[3].type == "response.function_call_arguments.done"
|
||||
assert chunks[4].type == "response.output_item.done"
|
||||
|
||||
# Check response.completed event (should have the tool call)
|
||||
assert chunks[1].type == "response.completed"
|
||||
assert len(chunks[1].response.output) == 1
|
||||
assert chunks[1].response.output[0].type == "function_call"
|
||||
assert chunks[1].response.output[0].name == "get_weather"
|
||||
assert chunks[5].type == "response.completed"
|
||||
assert len(chunks[5].response.output) == 1
|
||||
assert chunks[5].response.output[0].type == "function_call"
|
||||
assert chunks[5].response.output[0].name == "get_weather"
|
||||
|
||||
|
||||
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
|
||||
|
|
|
@@ -0,0 +1,310 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.agents.openai_responses import (
|
||||
OpenAIResponseInputFunctionToolCallOutput,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolWebSearch,
|
||||
OpenAIResponseMessage,
|
||||
OpenAIResponseOutputMessageContentOutputText,
|
||||
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
)
|
||||
from llama_stack.apis.inference import (
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIChatCompletionContentPartTextParam,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIChatCompletionToolCallFunction,
|
||||
OpenAIChoice,
|
||||
OpenAIDeveloperMessageParam,
|
||||
OpenAIResponseFormatJSONObject,
|
||||
OpenAIResponseFormatJSONSchema,
|
||||
OpenAIResponseFormatText,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIToolMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
|
||||
convert_chat_choice_to_response_message,
|
||||
convert_response_content_to_chat_content,
|
||||
convert_response_input_to_chat_messages,
|
||||
convert_response_text_to_chat_response_format,
|
||||
get_message_type_by_role,
|
||||
is_function_tool_call,
|
||||
)
|
||||
|
||||
|
||||
class TestConvertChatChoiceToResponseMessage:
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_string_content(self):
|
||||
choice = OpenAIChoice(
|
||||
message=OpenAIAssistantMessageParam(content="Test message"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
|
||||
result = await convert_chat_choice_to_response_message(choice)
|
||||
|
||||
assert result.role == "assistant"
|
||||
assert result.status == "completed"
|
||||
assert len(result.content) == 1
|
||||
assert isinstance(result.content[0], OpenAIResponseOutputMessageContentOutputText)
|
||||
assert result.content[0].text == "Test message"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_text_param_content(self):
|
||||
choice = OpenAIChoice(
|
||||
message=OpenAIAssistantMessageParam(
|
||||
content=[OpenAIChatCompletionContentPartTextParam(text="Test text param")]
|
||||
),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await convert_chat_choice_to_response_message(choice)
|
||||
|
||||
assert "does not yet support output content type" in str(exc_info.value)
|
||||
|
||||
|
||||
class TestConvertResponseContentToChatContent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_string_content(self):
|
||||
result = await convert_response_content_to_chat_content("Simple string")
|
||||
assert result == "Simple string"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_text_content_parts(self):
|
||||
content = [
|
||||
OpenAIResponseInputMessageContentText(text="First part"),
|
||||
OpenAIResponseOutputMessageContentOutputText(text="Second part"),
|
||||
]
|
||||
|
||||
result = await convert_response_content_to_chat_content(content)
|
||||
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], OpenAIChatCompletionContentPartTextParam)
|
||||
assert result[0].text == "First part"
|
||||
assert isinstance(result[1], OpenAIChatCompletionContentPartTextParam)
|
||||
assert result[1].text == "Second part"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_image_content(self):
|
||||
content = [OpenAIResponseInputMessageContentImage(image_url="https://example.com/image.jpg", detail="high")]
|
||||
|
||||
result = await convert_response_content_to_chat_content(content)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIChatCompletionContentPartImageParam)
|
||||
assert result[0].image_url.url == "https://example.com/image.jpg"
|
||||
assert result[0].image_url.detail == "high"
|
||||
|
||||
|
||||
class TestConvertResponseInputToChatMessages:
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_string_input(self):
|
||||
result = await convert_response_input_to_chat_messages("User message")
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIUserMessageParam)
|
||||
assert result[0].content == "User message"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_function_tool_call_output(self):
|
||||
input_items = [
|
||||
OpenAIResponseInputFunctionToolCallOutput(
|
||||
output="Tool output",
|
||||
call_id="call_123",
|
||||
)
|
||||
]
|
||||
|
||||
result = await convert_response_input_to_chat_messages(input_items)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIToolMessageParam)
|
||||
assert result[0].content == "Tool output"
|
||||
assert result[0].tool_call_id == "call_123"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_function_tool_call(self):
|
||||
input_items = [
|
||||
OpenAIResponseOutputMessageFunctionToolCall(
|
||||
call_id="call_456",
|
||||
name="test_function",
|
||||
arguments='{"param": "value"}',
|
||||
)
|
||||
]
|
||||
|
||||
result = await convert_response_input_to_chat_messages(input_items)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIAssistantMessageParam)
|
||||
assert len(result[0].tool_calls) == 1
|
||||
assert result[0].tool_calls[0].id == "call_456"
|
||||
assert result[0].tool_calls[0].function.name == "test_function"
|
||||
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_response_message(self):
|
||||
input_items = [
|
||||
OpenAIResponseMessage(
|
||||
role="user",
|
||||
content=[OpenAIResponseInputMessageContentText(text="User text")],
|
||||
)
|
||||
]
|
||||
|
||||
result = await convert_response_input_to_chat_messages(input_items)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIUserMessageParam)
|
||||
# Content should be converted to chat content format
|
||||
assert len(result[0].content) == 1
|
||||
assert result[0].content[0].text == "User text"
|
||||
|
||||
|
||||
class TestConvertResponseTextToChatResponseFormat:
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_text_format(self):
|
||||
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
|
||||
result = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
assert isinstance(result, OpenAIResponseFormatText)
|
||||
assert result.type == "text"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_json_object_format(self):
|
||||
text = OpenAIResponseText(format={"type": "json_object"})
|
||||
result = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
assert isinstance(result, OpenAIResponseFormatJSONObject)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_json_schema_format(self):
|
||||
schema_def = {"type": "object", "properties": {"test": {"type": "string"}}}
|
||||
text = OpenAIResponseText(
|
||||
format={
|
||||
"type": "json_schema",
|
||||
"name": "test_schema",
|
||||
"schema": schema_def,
|
||||
}
|
||||
)
|
||||
result = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
assert isinstance(result, OpenAIResponseFormatJSONSchema)
|
||||
assert result.json_schema["name"] == "test_schema"
|
||||
assert result.json_schema["schema"] == schema_def
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_text_format(self):
|
||||
text = OpenAIResponseText()
|
||||
result = await convert_response_text_to_chat_response_format(text)
|
||||
|
||||
assert isinstance(result, OpenAIResponseFormatText)
|
||||
assert result.type == "text"
|
||||
|
||||
|
||||
class TestGetMessageTypeByRole:
|
||||
@pytest.mark.asyncio
|
||||
async def test_user_role(self):
|
||||
result = await get_message_type_by_role("user")
|
||||
assert result == OpenAIUserMessageParam
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_system_role(self):
|
||||
result = await get_message_type_by_role("system")
|
||||
assert result == OpenAISystemMessageParam
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_assistant_role(self):
|
||||
result = await get_message_type_by_role("assistant")
|
||||
assert result == OpenAIAssistantMessageParam
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_developer_role(self):
|
||||
result = await get_message_type_by_role("developer")
|
||||
assert result == OpenAIDeveloperMessageParam
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_role(self):
|
||||
result = await get_message_type_by_role("unknown")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestIsFunctionToolCall:
|
||||
def test_is_function_tool_call_true(self):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id="call_123",
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name="test_function",
|
||||
arguments="{}",
|
||||
),
|
||||
)
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(
|
||||
type="function", name="test_function", parameters={"type": "object", "properties": {}}
|
||||
),
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
]
|
||||
|
||||
result = is_function_tool_call(tool_call, tools)
|
||||
assert result is True
|
||||
|
||||
def test_is_function_tool_call_false_different_name(self):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id="call_123",
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name="other_function",
|
||||
arguments="{}",
|
||||
),
|
||||
)
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(
|
||||
type="function", name="test_function", parameters={"type": "object", "properties": {}}
|
||||
),
|
||||
]
|
||||
|
||||
result = is_function_tool_call(tool_call, tools)
|
||||
assert result is False
|
||||
|
||||
def test_is_function_tool_call_false_no_function(self):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id="call_123",
|
||||
function=None,
|
||||
)
|
||||
tools = [
|
||||
OpenAIResponseInputToolFunction(
|
||||
type="function", name="test_function", parameters={"type": "object", "properties": {}}
|
||||
),
|
||||
]
|
||||
|
||||
result = is_function_tool_call(tool_call, tools)
|
||||
assert result is False
|
||||
|
||||
def test_is_function_tool_call_false_wrong_type(self):
|
||||
tool_call = OpenAIChatCompletionToolCall(
|
||||
index=0,
|
||||
id="call_123",
|
||||
function=OpenAIChatCompletionToolCallFunction(
|
||||
name="web_search",
|
||||
arguments="{}",
|
||||
),
|
||||
)
|
||||
tools = [
|
||||
OpenAIResponseInputToolWebSearch(type="web_search"),
|
||||
]
|
||||
|
||||
result = is_function_tool_call(tool_call, tools)
|
||||
assert result is False
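Taken together, the four TestIsFunctionToolCall cases above define the predicate's contract: it returns True only when the tool call actually carries a function and that function's name matches a function-type tool in the list. Below is a minimal sketch of such a predicate, assuming duck-typed tool objects; this is not the library's is_function_tool_call source.

# Illustrative sketch of the contract exercised above; not the library implementation.
def is_function_tool_call_sketch(tool_call, tools) -> bool:
    # No function attached (e.g. a non-function tool call) -> never a function tool call.
    if tool_call.function is None:
        return False
    for tool in tools:
        # Only function-type tools with the same name count as a match.
        if getattr(tool, "type", None) == "function" and getattr(tool, "name", None) == tool_call.function.name:
            return True
    return False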
|
tests/unit/providers/batches/test_reference.py (new file, 753 lines)
@@ -0,0 +1,753 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""
|
||||
Test suite for the reference implementation of the Batches API.
|
||||
|
||||
The tests are categorized and outlined below; keep this list updated:
|
||||
|
||||
- Batch creation with various parameters and validation:
|
||||
* test_create_and_retrieve_batch_success (positive)
|
||||
* test_create_batch_without_metadata (positive)
|
||||
* test_create_batch_completion_window (negative)
|
||||
* test_create_batch_invalid_endpoints (negative)
|
||||
* test_create_batch_invalid_metadata (negative)
|
||||
|
||||
- Batch retrieval and error handling for non-existent batches:
|
||||
* test_retrieve_batch_not_found (negative)
|
||||
|
||||
- Batch cancellation with proper status transitions:
|
||||
* test_cancel_batch_success (positive)
|
||||
* test_cancel_batch_invalid_statuses (negative)
|
||||
* test_cancel_batch_not_found (negative)
|
||||
|
||||
- Batch listing with pagination and filtering:
|
||||
* test_list_batches_empty (positive)
|
||||
* test_list_batches_single_batch (positive)
|
||||
* test_list_batches_multiple_batches (positive)
|
||||
* test_list_batches_with_limit (positive)
|
||||
* test_list_batches_with_pagination (positive)
|
||||
* test_list_batches_invalid_after (negative)
|
||||
|
||||
- Data persistence in the underlying key-value store:
|
||||
* test_kvstore_persistence (positive)
|
||||
|
||||
- Batch processing concurrency control:
|
||||
* test_max_concurrent_batches (positive)
|
||||
|
||||
- Input validation testing (direct _validate_input method tests):
|
||||
* test_validate_input_file_not_found (negative)
|
||||
* test_validate_input_file_exists_empty_content (positive)
|
||||
* test_validate_input_file_mixed_valid_invalid_json (mixed)
|
||||
* test_validate_input_invalid_model (negative)
|
||||
* test_validate_input_url_mismatch (negative)
|
||||
* test_validate_input_multiple_errors_per_request (negative)
|
||||
* test_validate_input_invalid_request_format (negative)
|
||||
* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
|
||||
* test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
|
||||
|
||||
The tests use temporary SQLite databases for isolation and mock external
|
||||
dependencies like inference, files, and models APIs.
|
||||
"""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.batches import BatchObject
|
||||
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
|
||||
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
|
||||
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
|
||||
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
|
||||
|
||||
|
||||
class TestReferenceBatchesImpl:
|
||||
"""Test the reference implementation of the Batches API."""
|
||||
|
||||
@pytest.fixture
|
||||
async def provider(self):
|
||||
"""Create a test provider instance with temporary database."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = Path(tmpdir) / "test_batches.db"
|
||||
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
|
||||
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
|
||||
|
||||
# Create kvstore and mock APIs
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
kvstore = await kvstore_impl(config.kvstore)
|
||||
mock_inference = AsyncMock()
|
||||
mock_files = AsyncMock()
|
||||
mock_models = AsyncMock()
|
||||
|
||||
provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore)
|
||||
await provider.initialize()
|
||||
|
||||
# unit tests should not require background processing
|
||||
provider.process_batches = False
|
||||
|
||||
yield provider
|
||||
|
||||
await provider.shutdown()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_batch_data(self):
|
||||
"""Sample batch data for testing."""
|
||||
return {
|
||||
"input_file_id": "file_abc123",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"metadata": {"test": "true", "priority": "high"},
|
||||
}
|
||||
|
||||
def _validate_batch_type(self, batch, expected_metadata=None):
|
||||
"""
|
||||
Helper function to validate batch object structure and field types.
|
||||
|
||||
Note: This validates the direct BatchObject from the provider, not the
|
||||
client library response which has a different structure.
|
||||
|
||||
Args:
|
||||
batch: The BatchObject instance to validate.
|
||||
expected_metadata: Optional expected metadata dictionary to validate against.
|
||||
"""
|
||||
assert isinstance(batch.id, str)
|
||||
assert isinstance(batch.completion_window, str)
|
||||
assert isinstance(batch.created_at, int)
|
||||
assert isinstance(batch.endpoint, str)
|
||||
assert isinstance(batch.input_file_id, str)
|
||||
assert batch.object == "batch"
|
||||
assert batch.status in [
|
||||
"validating",
|
||||
"failed",
|
||||
"in_progress",
|
||||
"finalizing",
|
||||
"completed",
|
||||
"expired",
|
||||
"cancelling",
|
||||
"cancelled",
|
||||
]
|
||||
|
||||
if expected_metadata is not None:
|
||||
assert batch.metadata == expected_metadata
|
||||
|
||||
timestamp_fields = [
|
||||
"cancelled_at",
|
||||
"cancelling_at",
|
||||
"completed_at",
|
||||
"expired_at",
|
||||
"expires_at",
|
||||
"failed_at",
|
||||
"finalizing_at",
|
||||
"in_progress_at",
|
||||
]
|
||||
for field in timestamp_fields:
|
||||
field_value = getattr(batch, field, None)
|
||||
if field_value is not None:
|
||||
assert isinstance(field_value, int), f"{field} should be int or None, got {type(field_value)}"
|
||||
|
||||
file_id_fields = ["error_file_id", "output_file_id"]
|
||||
for field in file_id_fields:
|
||||
field_value = getattr(batch, field, None)
|
||||
if field_value is not None:
|
||||
assert isinstance(field_value, str), f"{field} should be str or None, got {type(field_value)}"
|
||||
|
||||
if hasattr(batch, "request_counts") and batch.request_counts is not None:
|
||||
assert isinstance(batch.request_counts.completed, int), (
|
||||
f"request_counts.completed should be int, got {type(batch.request_counts.completed)}"
|
||||
)
|
||||
assert isinstance(batch.request_counts.failed, int), (
|
||||
f"request_counts.failed should be int, got {type(batch.request_counts.failed)}"
|
||||
)
|
||||
assert isinstance(batch.request_counts.total, int), (
|
||||
f"request_counts.total should be int, got {type(batch.request_counts.total)}"
|
||||
)
|
||||
|
||||
if hasattr(batch, "errors") and batch.errors is not None:
|
||||
assert isinstance(batch.errors, dict), f"errors should be object or dict, got {type(batch.errors)}"
|
||||
|
||||
if hasattr(batch.errors, "data") and batch.errors.data is not None:
|
||||
assert isinstance(batch.errors.data, list), (
|
||||
f"errors.data should be list or None, got {type(batch.errors.data)}"
|
||||
)
|
||||
|
||||
for i, error_item in enumerate(batch.errors.data):
|
||||
assert isinstance(error_item, dict), (
|
||||
f"errors.data[{i}] should be object or dict, got {type(error_item)}"
|
||||
)
|
||||
|
||||
if hasattr(error_item, "code") and error_item.code is not None:
|
||||
assert isinstance(error_item.code, str), (
|
||||
f"errors.data[{i}].code should be str or None, got {type(error_item.code)}"
|
||||
)
|
||||
|
||||
if hasattr(error_item, "line") and error_item.line is not None:
|
||||
assert isinstance(error_item.line, int), (
|
||||
f"errors.data[{i}].line should be int or None, got {type(error_item.line)}"
|
||||
)
|
||||
|
||||
if hasattr(error_item, "message") and error_item.message is not None:
|
||||
assert isinstance(error_item.message, str), (
|
||||
f"errors.data[{i}].message should be str or None, got {type(error_item.message)}"
|
||||
)
|
||||
|
||||
if hasattr(error_item, "param") and error_item.param is not None:
|
||||
assert isinstance(error_item.param, str), (
|
||||
f"errors.data[{i}].param should be str or None, got {type(error_item.param)}"
|
||||
)
|
||||
|
||||
if hasattr(batch.errors, "object") and batch.errors.object is not None:
|
||||
assert isinstance(batch.errors.object, str), (
|
||||
f"errors.object should be str or None, got {type(batch.errors.object)}"
|
||||
)
|
||||
assert batch.errors.object == "list", f"errors.object should be 'list', got {batch.errors.object}"
|
||||
|
||||
async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data):
|
||||
"""Test successful batch creation and retrieval."""
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"])
|
||||
|
||||
assert created_batch.id.startswith("batch_")
|
||||
assert len(created_batch.id) > 13
|
||||
assert created_batch.object == "batch"
|
||||
assert created_batch.endpoint == sample_batch_data["endpoint"]
|
||||
assert created_batch.input_file_id == sample_batch_data["input_file_id"]
|
||||
assert created_batch.completion_window == sample_batch_data["completion_window"]
|
||||
assert created_batch.status == "validating"
|
||||
assert created_batch.metadata == sample_batch_data["metadata"]
|
||||
assert isinstance(created_batch.created_at, int)
|
||||
assert created_batch.created_at > 0
|
||||
|
||||
retrieved_batch = await provider.retrieve_batch(created_batch.id)
|
||||
|
||||
self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"])
|
||||
|
||||
assert retrieved_batch.id == created_batch.id
|
||||
assert retrieved_batch.input_file_id == created_batch.input_file_id
|
||||
assert retrieved_batch.endpoint == created_batch.endpoint
|
||||
assert retrieved_batch.status == created_batch.status
|
||||
assert retrieved_batch.metadata == created_batch.metadata
|
||||
|
||||
async def test_create_batch_without_metadata(self, provider):
|
||||
"""Test batch creation without optional metadata."""
|
||||
batch = await provider.create_batch(
|
||||
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
|
||||
assert batch.metadata is None
|
||||
|
||||
async def test_create_batch_completion_window(self, provider):
|
||||
"""Test batch creation with invalid completion window."""
|
||||
with pytest.raises(ValueError, match="Invalid completion_window"):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now"
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"endpoint",
|
||||
[
|
||||
"/v1/embeddings",
|
||||
"/v1/completions",
|
||||
"/v1/invalid/endpoint",
|
||||
"",
|
||||
],
|
||||
)
|
||||
async def test_create_batch_invalid_endpoints(self, provider, endpoint):
|
||||
"""Test batch creation with various invalid endpoints."""
|
||||
with pytest.raises(ValueError, match="Invalid endpoint"):
|
||||
await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
|
||||
|
||||
async def test_create_batch_invalid_metadata(self, provider):
|
||||
"""Test that batch creation fails with invalid metadata."""
|
||||
with pytest.raises(ValueError, match="should be a valid string"):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={123: "invalid_key"}, # Non-string key
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="should be a valid string"):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_123",
|
||||
endpoint="/v1/chat/completions",
|
||||
completion_window="24h",
|
||||
metadata={"valid_key": 456}, # Non-string value
|
||||
)
|
||||
|
||||
async def test_retrieve_batch_not_found(self, provider):
|
||||
"""Test error when retrieving non-existent batch."""
|
||||
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
|
||||
await provider.retrieve_batch("nonexistent_batch")
|
||||
|
||||
async def test_cancel_batch_success(self, provider, sample_batch_data):
|
||||
"""Test successful batch cancellation."""
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
assert created_batch.status == "validating"
|
||||
|
||||
cancelled_batch = await provider.cancel_batch(created_batch.id)
|
||||
|
||||
assert cancelled_batch.id == created_batch.id
|
||||
assert cancelled_batch.status in ["cancelling", "cancelled"]
|
||||
assert isinstance(cancelled_batch.cancelling_at, int)
|
||||
assert cancelled_batch.cancelling_at >= created_batch.created_at
|
||||
|
||||
@pytest.mark.parametrize("status", ["failed", "expired", "completed"])
|
||||
async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status):
|
||||
"""Test error when cancelling batch in final states."""
|
||||
provider.process_batches = False
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
# directly update status in kvstore
|
||||
await provider._update_batch(created_batch.id, status=status)
|
||||
|
||||
with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"):
|
||||
await provider.cancel_batch(created_batch.id)
|
||||
|
||||
async def test_cancel_batch_not_found(self, provider):
|
||||
"""Test error when cancelling non-existent batch."""
|
||||
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
|
||||
await provider.cancel_batch("nonexistent_batch")
|
||||
|
||||
async def test_list_batches_empty(self, provider):
|
||||
"""Test listing batches when none exist."""
|
||||
response = await provider.list_batches()
|
||||
|
||||
assert response.object == "list"
|
||||
assert response.data == []
|
||||
assert response.first_id is None
|
||||
assert response.last_id is None
|
||||
assert response.has_more is False
|
||||
|
||||
async def test_list_batches_single_batch(self, provider, sample_batch_data):
|
||||
"""Test listing batches with single batch."""
|
||||
created_batch = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
response = await provider.list_batches()
|
||||
|
||||
assert len(response.data) == 1
|
||||
self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"])
|
||||
assert response.data[0].id == created_batch.id
|
||||
assert response.first_id == created_batch.id
|
||||
assert response.last_id == created_batch.id
|
||||
assert response.has_more is False
|
||||
|
||||
async def test_list_batches_multiple_batches(self, provider):
|
||||
"""Test listing multiple batches."""
|
||||
batches = [
|
||||
await provider.create_batch(
|
||||
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
response = await provider.list_batches()
|
||||
|
||||
assert len(response.data) == 3
|
||||
|
||||
batch_ids = {batch.id for batch in response.data}
|
||||
expected_ids = {batch.id for batch in batches}
|
||||
assert batch_ids == expected_ids
|
||||
assert response.has_more is False
|
||||
|
||||
assert response.first_id in expected_ids
|
||||
assert response.last_id in expected_ids
|
||||
|
||||
async def test_list_batches_with_limit(self, provider):
|
||||
"""Test listing batches with limit parameter."""
|
||||
batches = [
|
||||
await provider.create_batch(
|
||||
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
response = await provider.list_batches(limit=2)
|
||||
|
||||
assert len(response.data) == 2
|
||||
assert response.has_more is True
|
||||
assert response.first_id == response.data[0].id
|
||||
assert response.last_id == response.data[1].id
|
||||
batch_ids = {batch.id for batch in response.data}
|
||||
expected_ids = {batch.id for batch in batches}
|
||||
assert batch_ids.issubset(expected_ids)
|
||||
|
||||
async def test_list_batches_with_pagination(self, provider):
|
||||
"""Test listing batches with pagination using 'after' parameter."""
|
||||
for i in range(3):
|
||||
await provider.create_batch(
|
||||
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
|
||||
# Get first page
|
||||
first_page = await provider.list_batches(limit=1)
|
||||
assert len(first_page.data) == 1
|
||||
assert first_page.has_more is True
|
||||
|
||||
# Get second page using 'after'
|
||||
second_page = await provider.list_batches(limit=1, after=first_page.data[0].id)
|
||||
assert len(second_page.data) == 1
|
||||
assert second_page.data[0].id != first_page.data[0].id
|
||||
|
||||
# Verify we got the next batch in order
|
||||
all_batches = await provider.list_batches()
|
||||
expected_second_batch_id = all_batches.data[1].id
|
||||
assert second_page.data[0].id == expected_second_batch_id
|
||||
|
||||
async def test_list_batches_invalid_after(self, provider, sample_batch_data):
|
||||
"""Test listing batches with invalid 'after' parameter."""
|
||||
await provider.create_batch(**sample_batch_data)
|
||||
|
||||
response = await provider.list_batches(after="nonexistent_batch")
|
||||
|
||||
# Should return all batches (no filtering when 'after' batch not found)
|
||||
assert len(response.data) == 1
|
||||
|
||||
async def test_kvstore_persistence(self, provider, sample_batch_data):
|
||||
"""Test that batches are properly persisted in kvstore."""
|
||||
batch = await provider.create_batch(**sample_batch_data)
|
||||
|
||||
stored_data = await provider.kvstore.get(f"batch:{batch.id}")
|
||||
assert stored_data is not None
|
||||
|
||||
stored_batch_dict = json.loads(stored_data)
|
||||
assert stored_batch_dict["id"] == batch.id
|
||||
assert stored_batch_dict["input_file_id"] == sample_batch_data["input_file_id"]
|
||||
|
||||
async def test_validate_input_file_not_found(self, provider):
|
||||
"""Test _validate_input when input file does not exist."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock(side_effect=Exception("File not found"))
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="nonexistent_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
assert errors[0].code == "invalid_request"
|
||||
assert errors[0].message == "Cannot find file nonexistent_file."
|
||||
assert errors[0].param == "input_file_id"
|
||||
assert errors[0].line is None
|
||||
|
||||
async def test_validate_input_file_exists_empty_content(self, provider):
|
||||
"""Test _validate_input when file exists but is empty."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.body = b""
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="empty_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 0
|
||||
assert len(requests) == 0
|
||||
|
||||
async def test_validate_input_file_mixed_valid_invalid_json(self, provider):
|
||||
"""Test _validate_input when file contains valid and invalid JSON lines."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
# Line 1: valid JSON with proper body args, Line 2: invalid JSON
|
||||
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}\n{invalid json'
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="mixed_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
# Should have 1 JSON parsing error from line 2, and 1 valid request from line 1
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 1
|
||||
|
||||
assert errors[0].code == "invalid_json_line"
|
||||
assert errors[0].line == 2
|
||||
assert errors[0].message == "This line is not parseable as valid JSON."
|
||||
|
||||
assert requests[0].custom_id == "req-1"
|
||||
assert requests[0].method == "POST"
|
||||
assert requests[0].url == "/v1/chat/completions"
|
||||
assert requests[0].body["model"] == "test-model"
|
||||
assert requests[0].body["messages"] == [{"role": "user", "content": "Hello"}]
|
||||
|
||||
async def test_validate_input_invalid_model(self, provider):
|
||||
"""Test _validate_input when file contains request with non-existent model."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "nonexistent-model", "messages": [{"role": "user", "content": "Hello"}]}}'
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
provider.models_api.get_model = AsyncMock(side_effect=Exception("Model not found"))
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="invalid_model_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == "model_not_found"
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == "Model 'nonexistent-model' does not exist or is not supported"
|
||||
assert errors[0].param == "body.model"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"param_name,param_path,error_code,error_message",
|
||||
[
|
||||
("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
|
||||
("method", "method", "missing_required_parameter", "Missing required parameter: method"),
|
||||
("url", "url", "missing_required_parameter", "Missing required parameter: url"),
|
||||
("body", "body", "missing_required_parameter", "Missing required parameter: body"),
|
||||
("model", "body.model", "invalid_request", "Model parameter is required"),
|
||||
("messages", "body.messages", "invalid_request", "Messages parameter is required"),
|
||||
],
|
||||
)
|
||||
async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
|
||||
"""Test _validate_input when file contains request with missing required parameters."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
base_request = {
|
||||
"custom_id": "req-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
|
||||
}
|
||||
|
||||
# Remove the specific parameter being tested
|
||||
if "." in param_path:
|
||||
top_level, nested_param = param_path.split(".", 1)
|
||||
del base_request[top_level][nested_param]
|
||||
else:
|
||||
del base_request[param_name]
|
||||
|
||||
mock_response.body = json.dumps(base_request).encode()
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id=f"missing_{param_name}_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == error_code
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == error_message
|
||||
assert errors[0].param == param_path
|
||||
|
||||
async def test_validate_input_url_mismatch(self, provider):
|
||||
"""Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}'
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions", # This doesn't match the URL in the request
|
||||
input_file_id="url_mismatch_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == "invalid_url"
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == "URL provided for this request does not match the batch endpoint"
|
||||
assert errors[0].param == "url"
|
||||
|
||||
async def test_validate_input_multiple_errors_per_request(self, provider):
|
||||
"""Test _validate_input when a single request has multiple validation errors."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
# Request missing custom_id, has invalid URL, and missing model in body
|
||||
mock_response.body = (
|
||||
b'{"method": "POST", "url": "/v1/embeddings", "body": {"messages": [{"role": "user", "content": "Hello"}]}}'
|
||||
)
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions", # Doesn't match /v1/embeddings in request
|
||||
input_file_id="multiple_errors_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) >= 2 # At least missing custom_id and URL mismatch
|
||||
assert len(requests) == 0
|
||||
|
||||
for error in errors:
|
||||
assert error.line == 1
|
||||
|
||||
error_codes = {error.code for error in errors}
|
||||
assert "missing_required_parameter" in error_codes # missing custom_id
|
||||
assert "invalid_url" in error_codes # URL mismatch
|
||||
|
||||
async def test_validate_input_invalid_request_format(self, provider):
|
||||
"""Test _validate_input when file contains non-object JSON (array, string, number)."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.body = b'["not", "a", "request", "object"]'
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id="invalid_format_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == "invalid_request"
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == "Each line must be a JSON dictionary object"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"param_name,param_path,invalid_value,error_message",
|
||||
[
|
||||
("custom_id", "custom_id", 12345, "Custom_id must be a string"),
|
||||
("url", "url", 123, "URL must be a string"),
|
||||
("method", "method", ["POST"], "Method must be a string"),
|
||||
("body", "body", ["not", "valid"], "Body must be a JSON dictionary object"),
|
||||
("model", "body.model", 123, "Model must be a string"),
|
||||
("messages", "body.messages", "invalid messages format", "Messages must be an array"),
|
||||
],
|
||||
)
|
||||
async def test_validate_input_invalid_parameter_types(
|
||||
self, provider, param_name, param_path, invalid_value, error_message
|
||||
):
|
||||
"""Test _validate_input when file contains request with parameters that have invalid types."""
|
||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||
mock_response = MagicMock()
|
||||
|
||||
base_request = {
|
||||
"custom_id": "req-1",
|
||||
"method": "POST",
|
||||
"url": "/v1/chat/completions",
|
||||
"body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
|
||||
}
|
||||
|
||||
# Override the specific parameter with invalid value
|
||||
if "." in param_path:
|
||||
top_level, nested_param = param_path.split(".", 1)
|
||||
base_request[top_level][nested_param] = invalid_value
|
||||
else:
|
||||
base_request[param_name] = invalid_value
|
||||
|
||||
mock_response.body = json.dumps(base_request).encode()
|
||||
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||
|
||||
batch = BatchObject(
|
||||
id="batch_test",
|
||||
object="batch",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id=f"invalid_{param_name}_type_file",
|
||||
completion_window="24h",
|
||||
status="validating",
|
||||
created_at=1234567890,
|
||||
)
|
||||
|
||||
errors, requests = await provider._validate_input(batch)
|
||||
|
||||
assert len(errors) == 1
|
||||
assert len(requests) == 0
|
||||
|
||||
assert errors[0].code == "invalid_request"
|
||||
assert errors[0].line == 1
|
||||
assert errors[0].message == error_message
|
||||
assert errors[0].param == param_path
|
||||
|
||||
async def test_max_concurrent_batches(self, provider):
|
||||
"""Test max_concurrent_batches configuration and concurrency control."""
|
||||
import asyncio
|
||||
|
||||
provider._batch_semaphore = asyncio.Semaphore(2)
|
||||
|
||||
provider.process_batches = True # enable because we're testing background processing
|
||||
|
||||
active_batches = 0
|
||||
|
||||
async def add_and_wait(batch_id: str):
|
||||
nonlocal active_batches
|
||||
active_batches += 1
|
||||
await asyncio.sleep(float("inf"))
|
||||
|
||||
# the first thing done in _process_batch is to acquire the semaphore, then call _process_batch_impl,
|
||||
# so we can replace _process_batch_impl with our mock to control concurrency
|
||||
provider._process_batch_impl = add_and_wait
|
||||
|
||||
for _ in range(3):
|
||||
await provider.create_batch(
|
||||
input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h"
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.042) # let tasks start
|
||||
|
||||
assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"
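The comment in test_max_concurrent_batches leans on one detail of the provider: _process_batch acquires the semaphore before delegating to _process_batch_impl, so patching the latter is enough to observe concurrency. A hedged sketch of that pattern follows; the class and method bodies are assumptions for illustration, not the provider's actual code.

# Illustrative sketch of the semaphore pattern the test relies on; not ReferenceBatchesImpl.
import asyncio


class BatchRunnerSketch:
    def __init__(self, max_concurrent_batches: int = 2) -> None:
        # Bounds how many batches may be processed at once.
        self._batch_semaphore = asyncio.Semaphore(max_concurrent_batches)

    async def _process_batch(self, batch_id: str) -> None:
        # Acquire the semaphore first, then do the actual work; the test patches
        # _process_batch_impl to count how many acquisitions are active at once.
        async with self._batch_semaphore:
            await self._process_batch_impl(batch_id)

    async def _process_batch_impl(self, batch_id: str) -> None:
        await asyncio.sleep(0)  # placeholder for real batch processing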
|
|
@@ -24,6 +24,7 @@ from llama_stack.apis.inference import (
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
convert_message_to_openai_dict_new,
|
||||
openai_messages_to_messages,
|
||||
)
|
||||
|
||||
|
@@ -182,3 +183,42 @@ def test_user_message_accepts_images():
assert len(msg.content) == 2
|
||||
assert msg.content[0].text == "Describe this image:"
|
||||
assert msg.content[1].image_url.url == "http://example.com/image.jpg"
|
||||
|
||||
|
||||
async def test_convert_message_to_openai_dict_new_user_message():
|
||||
"""Test convert_message_to_openai_dict_new with UserMessage."""
|
||||
message = UserMessage(content="Hello, world!", role="user")
|
||||
result = await convert_message_to_openai_dict_new(message)
|
||||
|
||||
assert result["role"] == "user"
|
||||
assert result["content"] == "Hello, world!"
|
||||
|
||||
|
||||
async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls():
|
||||
"""Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls."""
|
||||
message = CompletionMessage(
|
||||
content="I'll help you find the weather.",
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
call_id="call_123",
|
||||
tool_name="get_weather",
|
||||
arguments={"city": "Sligo"},
|
||||
arguments_json='{"city": "Sligo"}',
|
||||
)
|
||||
],
|
||||
stop_reason=StopReason.end_of_turn,
|
||||
)
|
||||
result = await convert_message_to_openai_dict_new(message)
|
||||
|
||||
# This would have failed with "Cannot instantiate typing.Union" before the fix
|
||||
assert result["role"] == "assistant"
|
||||
assert result["content"] == "I'll help you find the weather."
|
||||
assert "tool_calls" in result
|
||||
assert result["tool_calls"] is not None
|
||||
assert len(result["tool_calls"]) == 1
|
||||
|
||||
tool_call = result["tool_calls"][0]
|
||||
assert tool_call.id == "call_123"
|
||||
assert tool_call.type == "function"
|
||||
assert tool_call.function.name == "get_weather"
|
||||
assert tool_call.function.arguments == '{"city": "Sligo"}'
|
||||
|
|