Merge branch 'main' into chroma

Bwook (Byoungwook) Kim 2025-08-18 16:11:36 +09:00 committed by GitHub
commit c66ebae9b6
207 changed files with 15490 additions and 7927 deletions


@@ -0,0 +1,347 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from datetime import UTC, datetime
from unittest.mock import AsyncMock, patch
import pytest
from llama_stack.apis.agents import Session
from llama_stack.core.datatypes import User
from llama_stack.providers.inline.agents.meta_reference.persistence import (
AgentPersistence,
AgentSessionInfo,
)
from llama_stack.providers.utils.kvstore import KVStore
@pytest.fixture
def mock_kvstore():
return AsyncMock(spec=KVStore)
@pytest.fixture
def mock_policy():
return []
@pytest.fixture
def agent_persistence(mock_kvstore, mock_policy):
return AgentPersistence(agent_id="test-agent-123", kvstore=mock_kvstore, policy=mock_policy)
@pytest.fixture
def sample_session():
return AgentSessionInfo(
session_id="session-123",
session_name="Test Session",
started_at=datetime.now(UTC),
owner=User(principal="user-123", attributes=None),
turns=[],
identifier="test-session",
type="session",
)
@pytest.fixture
def sample_session_json(sample_session):
return sample_session.model_dump_json()
class TestAgentPersistenceListSessions:
def setup_mock_kvstore(self, mock_kvstore, session_keys=None, turn_keys=None, invalid_keys=None, custom_data=None):
"""Helper to setup mock kvstore with sessions, turns, and custom/invalid data
Args:
mock_kvstore: The mock KVStore object
session_keys: List of session keys or dict mapping keys to custom session data
turn_keys: List of turn keys or dict mapping keys to custom turn data
invalid_keys: Dict mapping keys to invalid/corrupt data
custom_data: Additional custom data to add to the mock responses
"""
all_keys = []
mock_data = {}
# session keys
if session_keys:
if isinstance(session_keys, dict):
all_keys.extend(session_keys.keys())
mock_data.update({k: json.dumps(v) if isinstance(v, dict) else v for k, v in session_keys.items()})
else:
all_keys.extend(session_keys)
for key in session_keys:
session_id = key.split(":")[-1]
mock_data[key] = json.dumps(
{
"session_id": session_id,
"session_name": f"Session {session_id}",
"started_at": datetime.now(UTC).isoformat(),
"turns": [],
}
)
# turn keys
if turn_keys:
if isinstance(turn_keys, dict):
all_keys.extend(turn_keys.keys())
mock_data.update({k: json.dumps(v) if isinstance(v, dict) else v for k, v in turn_keys.items()})
else:
all_keys.extend(turn_keys)
for key in turn_keys:
parts = key.split(":")
session_id = parts[-2]
turn_id = parts[-1]
mock_data[key] = json.dumps(
{
"turn_id": turn_id,
"session_id": session_id,
"input_messages": [],
"started_at": datetime.now(UTC).isoformat(),
}
)
if invalid_keys:
all_keys.extend(invalid_keys.keys())
mock_data.update(invalid_keys)
if custom_data:
mock_data.update(custom_data)
values_list = list(mock_data.values())
mock_kvstore.values_in_range.return_value = values_list
async def mock_get(key):
return mock_data.get(key)
mock_kvstore.get.side_effect = mock_get
return mock_data
@pytest.mark.parametrize(
"scenario",
[
{
# from this issue: https://github.com/meta-llama/llama-stack/issues/3048
"name": "reported_bug",
"session_keys": ["session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"],
"turn_keys": [
"session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d:eb7e818f-41fb-49a0-bdd6-464974a2d2ad"
],
"expected_sessions": ["1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"],
},
{
"name": "basic_filtering",
"session_keys": ["session:test-agent-123:session-1", "session:test-agent-123:session-2"],
"turn_keys": ["session:test-agent-123:session-1:turn-1", "session:test-agent-123:session-1:turn-2"],
"expected_sessions": ["session-1", "session-2"],
},
{
"name": "multiple_turns_per_session",
"session_keys": ["session:test-agent-123:session-456"],
"turn_keys": [
"session:test-agent-123:session-456:turn-789",
"session:test-agent-123:session-456:turn-790",
],
"expected_sessions": ["session-456"],
},
{
"name": "multiple_sessions_with_turns",
"session_keys": ["session:test-agent-123:session-1", "session:test-agent-123:session-2"],
"turn_keys": [
"session:test-agent-123:session-1:turn-1",
"session:test-agent-123:session-1:turn-2",
"session:test-agent-123:session-2:turn-3",
],
"expected_sessions": ["session-1", "session-2"],
},
],
)
async def test_list_sessions_key_filtering(self, agent_persistence, mock_kvstore, scenario):
self.setup_mock_kvstore(mock_kvstore, session_keys=scenario["session_keys"], turn_keys=scenario["turn_keys"])
with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
result = await agent_persistence.list_sessions()
assert len(result) == len(scenario["expected_sessions"])
session_ids = {s.session_id for s in result}
for expected_id in scenario["expected_sessions"]:
assert expected_id in session_ids
# no errors should be logged
mock_log.error.assert_not_called()
@pytest.mark.parametrize(
"error_scenario",
[
{
"name": "invalid_json",
"valid_keys": ["session:test-agent-123:valid-session"],
"invalid_data": {"session:test-agent-123:invalid-json": "corrupted-json-data{"},
"expected_valid_sessions": ["valid-session"],
"expected_error_count": 1,
},
{
"name": "missing_fields",
"valid_keys": ["session:test-agent-123:valid-session"],
"invalid_data": {
"session:test-agent-123:invalid-schema": json.dumps(
{
"session_id": "invalid-schema",
"session_name": "Missing Fields",
# missing `started_at` and `turns`
}
)
},
"expected_valid_sessions": ["valid-session"],
"expected_error_count": 1,
},
{
"name": "multiple_invalid",
"valid_keys": ["session:test-agent-123:valid-session-1", "session:test-agent-123:valid-session-2"],
"invalid_data": {
"session:test-agent-123:corrupted-json": "not-valid-json{",
"session:test-agent-123:incomplete-data": json.dumps({"incomplete": "data"}),
},
"expected_valid_sessions": ["valid-session-1", "valid-session-2"],
"expected_error_count": 2,
},
],
)
async def test_list_sessions_error_handling(self, agent_persistence, mock_kvstore, error_scenario):
session_keys = {}
for key in error_scenario["valid_keys"]:
session_id = key.split(":")[-1]
session_keys[key] = {
"session_id": session_id,
"session_name": f"Valid {session_id}",
"started_at": datetime.now(UTC).isoformat(),
"turns": [],
}
self.setup_mock_kvstore(mock_kvstore, session_keys=session_keys, invalid_keys=error_scenario["invalid_data"])
with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
result = await agent_persistence.list_sessions()
# only valid sessions should be returned
assert len(result) == len(error_scenario["expected_valid_sessions"])
session_ids = {s.session_id for s in result}
for expected_id in error_scenario["expected_valid_sessions"]:
assert expected_id in session_ids
# error should be logged
assert mock_log.error.call_count > 0
assert mock_log.error.call_count == error_scenario["expected_error_count"]
async def test_list_sessions_empty(self, agent_persistence, mock_kvstore):
mock_kvstore.values_in_range.return_value = []
result = await agent_persistence.list_sessions()
assert result == []
mock_kvstore.values_in_range.assert_called_once_with(
start_key="session:test-agent-123:", end_key="session:test-agent-123:\xff\xff\xff\xff"
)
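# the "\xff\xff\xff\xff" suffix presumably acts as a high sentinel so the range scan
# covers every key under the "session:test-agent-123:" prefix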
async def test_list_sessions_properties(self, agent_persistence, mock_kvstore):
session_data = {
"session_id": "session-123",
"session_name": "Test Session",
"started_at": datetime.now(UTC).isoformat(),
"owner": {"principal": "user-123", "attributes": None},
"turns": [],
}
self.setup_mock_kvstore(mock_kvstore, session_keys={"session:test-agent-123:session-123": session_data})
result = await agent_persistence.list_sessions()
assert len(result) == 1
assert isinstance(result[0], Session)
assert result[0].session_id == "session-123"
assert result[0].session_name == "Test Session"
assert result[0].turns == []
assert hasattr(result[0], "started_at")
async def test_list_sessions_kvstore_exception(self, agent_persistence, mock_kvstore):
mock_kvstore.values_in_range.side_effect = Exception("KVStore error")
with pytest.raises(Exception, match="KVStore error"):
await agent_persistence.list_sessions()
async def test_bug_data_loss_with_real_data(self, agent_persistence, mock_kvstore):
# tests the handling of the issue reported in: https://github.com/meta-llama/llama-stack/issues/3048
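# the turn record below shares the "session:" key prefix, so the range scan returns it
# alongside the session; list_sessions is expected to skip it rather than fail to parse
# it as a Session and drop the valid session data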
session_data = {
"session_id": "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d",
"session_name": "Test Session",
"started_at": datetime.now(UTC).isoformat(),
"turns": [],
}
turn_data = {
"turn_id": "eb7e818f-41fb-49a0-bdd6-464974a2d2ad",
"session_id": "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d",
"input_messages": [
{"role": "user", "content": "if i had a cluster i would want to call it persistence01", "context": None}
],
"steps": [
{
"turn_id": "eb7e818f-41fb-49a0-bdd6-464974a2d2ad",
"step_id": "c0f797dd-3d34-4bc5-a8f4-db6af9455132",
"started_at": "2025-08-05T14:31:50.000484Z",
"completed_at": "2025-08-05T14:31:51.303691Z",
"step_type": "inference",
"model_response": {
"role": "assistant",
"content": "OK, I can create a cluster named 'persistence01' for you.",
"stop_reason": "end_of_turn",
"tool_calls": [],
},
}
],
"output_message": {
"role": "assistant",
"content": "OK, I can create a cluster named 'persistence01' for you.",
"stop_reason": "end_of_turn",
"tool_calls": [],
},
"output_attachments": [],
"started_at": "2025-08-05T14:31:49.999950Z",
"completed_at": "2025-08-05T14:31:51.305384Z",
}
mock_data = {
"session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d": json.dumps(session_data),
"session:test-agent-123:1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d:eb7e818f-41fb-49a0-bdd6-464974a2d2ad": json.dumps(
turn_data
),
}
mock_kvstore.values_in_range.return_value = list(mock_data.values())
async def mock_get(key):
return mock_data.get(key)
mock_kvstore.get.side_effect = mock_get
with patch("llama_stack.providers.inline.agents.meta_reference.persistence.log") as mock_log:
result = await agent_persistence.list_sessions()
assert len(result) == 1
assert result[0].session_id == "1f08fd1c-5a9d-459d-a00b-36d4dfa49b7d"
# confirm no errors logged
mock_log.error.assert_not_called()
async def test_list_sessions_key_range_construction(self, agent_persistence, mock_kvstore):
mock_kvstore.values_in_range.return_value = []
await agent_persistence.list_sessions()
mock_kvstore.values_in_range.assert_called_once_with(
start_key="session:test-agent-123:", end_key="session:test-agent-123:\xff\xff\xff\xff"
)


@@ -41,7 +41,7 @@ from llama_stack.apis.inference (
)
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
from llama_stack.core.access_control.access_control import default_policy
from llama_stack.providers.inline.agents.meta_reference.openai_responses import (
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
@@ -136,9 +136,12 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
input=input_text,
model=model,
temperature=0.1,
stream=True, # Enable streaming to test content part events
)
# Verify
# For streaming response, collect all chunks
chunks = [chunk async for chunk in result]
mock_inference_api.openai_chat_completion.assert_called_once_with(
model=model,
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
@@ -147,11 +150,32 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
stream=True,
temperature=0.1,
)
# Should have content part events for text streaming
# Expected: response.created, content_part.added, output_text.delta, content_part.done, response.completed
assert len(chunks) >= 4
assert chunks[0].type == "response.created"
# Check for content part events
content_part_added_events = [c for c in chunks if c.type == "response.content_part.added"]
content_part_done_events = [c for c in chunks if c.type == "response.content_part.done"]
text_delta_events = [c for c in chunks if c.type == "response.output_text.delta"]
assert len(content_part_added_events) >= 1, "Should have content_part.added event for text"
assert len(content_part_done_events) >= 1, "Should have content_part.done event for text"
assert len(text_delta_events) >= 1, "Should have text delta events"
# Verify final event is completion
assert chunks[-1].type == "response.completed"
# When streaming, the final response is in the last chunk
final_response = chunks[-1].response
assert final_response.model == model
assert len(final_response.output) == 1
assert isinstance(final_response.output[0], OpenAIResponseMessage)
openai_responses_impl.responses_store.store_response_object.assert_called_once()
assert result.model == model
assert len(result.output) == 1
assert isinstance(result.output[0], OpenAIResponseMessage)
assert result.output[0].content[0].text == "Dublin"
assert final_response.output[0].content[0].text == "Dublin"
async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
@@ -272,7 +296,11 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
# Check that we got the content from our mocked tool execution result
chunks = [chunk async for chunk in result]
assert len(chunks) == 2 # Should have response.created and response.completed
# Verify event types
# Should have: response.created, output_item.added, function_call_arguments.delta,
# function_call_arguments.done, output_item.done, response.completed
assert len(chunks) == 6
# Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
@@ -284,11 +312,17 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
assert chunks[0].type == "response.created"
assert len(chunks[0].response.output) == 0
# Check streaming events
assert chunks[1].type == "response.output_item.added"
assert chunks[2].type == "response.function_call_arguments.delta"
assert chunks[3].type == "response.function_call_arguments.done"
assert chunks[4].type == "response.output_item.done"
# Check response.completed event (should have the tool call)
assert chunks[1].type == "response.completed"
assert len(chunks[1].response.output) == 1
assert chunks[1].response.output[0].type == "function_call"
assert chunks[1].response.output[0].name == "get_weather"
assert chunks[5].type == "response.completed"
assert len(chunks[5].response.output) == 1
assert chunks[5].response.output[0].type == "function_call"
assert chunks[5].response.output[0].name == "get_weather"
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):


@@ -0,0 +1,310 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.agents.openai_responses import (
OpenAIResponseInputFunctionToolCallOutput,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolWebSearch,
OpenAIResponseMessage,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseText,
OpenAIResponseTextFormat,
)
from llama_stack.apis.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIChoice,
OpenAIDeveloperMessageParam,
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatText,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
convert_chat_choice_to_response_message,
convert_response_content_to_chat_content,
convert_response_input_to_chat_messages,
convert_response_text_to_chat_response_format,
get_message_type_by_role,
is_function_tool_call,
)
class TestConvertChatChoiceToResponseMessage:
@pytest.mark.asyncio
async def test_convert_string_content(self):
choice = OpenAIChoice(
message=OpenAIAssistantMessageParam(content="Test message"),
finish_reason="stop",
index=0,
)
result = await convert_chat_choice_to_response_message(choice)
assert result.role == "assistant"
assert result.status == "completed"
assert len(result.content) == 1
assert isinstance(result.content[0], OpenAIResponseOutputMessageContentOutputText)
assert result.content[0].text == "Test message"
@pytest.mark.asyncio
async def test_convert_text_param_content(self):
choice = OpenAIChoice(
message=OpenAIAssistantMessageParam(
content=[OpenAIChatCompletionContentPartTextParam(text="Test text param")]
),
finish_reason="stop",
index=0,
)
with pytest.raises(ValueError) as exc_info:
await convert_chat_choice_to_response_message(choice)
assert "does not yet support output content type" in str(exc_info.value)
class TestConvertResponseContentToChatContent:
@pytest.mark.asyncio
async def test_convert_string_content(self):
result = await convert_response_content_to_chat_content("Simple string")
assert result == "Simple string"
@pytest.mark.asyncio
async def test_convert_text_content_parts(self):
content = [
OpenAIResponseInputMessageContentText(text="First part"),
OpenAIResponseOutputMessageContentOutputText(text="Second part"),
]
result = await convert_response_content_to_chat_content(content)
assert len(result) == 2
assert isinstance(result[0], OpenAIChatCompletionContentPartTextParam)
assert result[0].text == "First part"
assert isinstance(result[1], OpenAIChatCompletionContentPartTextParam)
assert result[1].text == "Second part"
@pytest.mark.asyncio
async def test_convert_image_content(self):
content = [OpenAIResponseInputMessageContentImage(image_url="https://example.com/image.jpg", detail="high")]
result = await convert_response_content_to_chat_content(content)
assert len(result) == 1
assert isinstance(result[0], OpenAIChatCompletionContentPartImageParam)
assert result[0].image_url.url == "https://example.com/image.jpg"
assert result[0].image_url.detail == "high"
class TestConvertResponseInputToChatMessages:
@pytest.mark.asyncio
async def test_convert_string_input(self):
result = await convert_response_input_to_chat_messages("User message")
assert len(result) == 1
assert isinstance(result[0], OpenAIUserMessageParam)
assert result[0].content == "User message"
@pytest.mark.asyncio
async def test_convert_function_tool_call_output(self):
input_items = [
OpenAIResponseInputFunctionToolCallOutput(
output="Tool output",
call_id="call_123",
)
]
result = await convert_response_input_to_chat_messages(input_items)
assert len(result) == 1
assert isinstance(result[0], OpenAIToolMessageParam)
assert result[0].content == "Tool output"
assert result[0].tool_call_id == "call_123"
@pytest.mark.asyncio
async def test_convert_function_tool_call(self):
input_items = [
OpenAIResponseOutputMessageFunctionToolCall(
call_id="call_456",
name="test_function",
arguments='{"param": "value"}',
)
]
result = await convert_response_input_to_chat_messages(input_items)
assert len(result) == 1
assert isinstance(result[0], OpenAIAssistantMessageParam)
assert len(result[0].tool_calls) == 1
assert result[0].tool_calls[0].id == "call_456"
assert result[0].tool_calls[0].function.name == "test_function"
assert result[0].tool_calls[0].function.arguments == '{"param": "value"}'
@pytest.mark.asyncio
async def test_convert_response_message(self):
input_items = [
OpenAIResponseMessage(
role="user",
content=[OpenAIResponseInputMessageContentText(text="User text")],
)
]
result = await convert_response_input_to_chat_messages(input_items)
assert len(result) == 1
assert isinstance(result[0], OpenAIUserMessageParam)
# Content should be converted to chat content format
assert len(result[0].content) == 1
assert result[0].content[0].text == "User text"
class TestConvertResponseTextToChatResponseFormat:
@pytest.mark.asyncio
async def test_convert_text_format(self):
text = OpenAIResponseText(format=OpenAIResponseTextFormat(type="text"))
result = await convert_response_text_to_chat_response_format(text)
assert isinstance(result, OpenAIResponseFormatText)
assert result.type == "text"
@pytest.mark.asyncio
async def test_convert_json_object_format(self):
text = OpenAIResponseText(format={"type": "json_object"})
result = await convert_response_text_to_chat_response_format(text)
assert isinstance(result, OpenAIResponseFormatJSONObject)
@pytest.mark.asyncio
async def test_convert_json_schema_format(self):
schema_def = {"type": "object", "properties": {"test": {"type": "string"}}}
text = OpenAIResponseText(
format={
"type": "json_schema",
"name": "test_schema",
"schema": schema_def,
}
)
result = await convert_response_text_to_chat_response_format(text)
assert isinstance(result, OpenAIResponseFormatJSONSchema)
assert result.json_schema["name"] == "test_schema"
assert result.json_schema["schema"] == schema_def
@pytest.mark.asyncio
async def test_default_text_format(self):
text = OpenAIResponseText()
result = await convert_response_text_to_chat_response_format(text)
assert isinstance(result, OpenAIResponseFormatText)
assert result.type == "text"
class TestGetMessageTypeByRole:
@pytest.mark.asyncio
async def test_user_role(self):
result = await get_message_type_by_role("user")
assert result == OpenAIUserMessageParam
@pytest.mark.asyncio
async def test_system_role(self):
result = await get_message_type_by_role("system")
assert result == OpenAISystemMessageParam
@pytest.mark.asyncio
async def test_assistant_role(self):
result = await get_message_type_by_role("assistant")
assert result == OpenAIAssistantMessageParam
@pytest.mark.asyncio
async def test_developer_role(self):
result = await get_message_type_by_role("developer")
assert result == OpenAIDeveloperMessageParam
@pytest.mark.asyncio
async def test_unknown_role(self):
result = await get_message_type_by_role("unknown")
assert result is None
class TestIsFunctionToolCall:
def test_is_function_tool_call_true(self):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id="call_123",
function=OpenAIChatCompletionToolCallFunction(
name="test_function",
arguments="{}",
),
)
tools = [
OpenAIResponseInputToolFunction(
type="function", name="test_function", parameters={"type": "object", "properties": {}}
),
OpenAIResponseInputToolWebSearch(type="web_search"),
]
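# should match: the call names a function-type tool that is present in `tools`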
result = is_function_tool_call(tool_call, tools)
assert result is True
def test_is_function_tool_call_false_different_name(self):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id="call_123",
function=OpenAIChatCompletionToolCallFunction(
name="other_function",
arguments="{}",
),
)
tools = [
OpenAIResponseInputToolFunction(
type="function", name="test_function", parameters={"type": "object", "properties": {}}
),
]
result = is_function_tool_call(tool_call, tools)
assert result is False
def test_is_function_tool_call_false_no_function(self):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id="call_123",
function=None,
)
tools = [
OpenAIResponseInputToolFunction(
type="function", name="test_function", parameters={"type": "object", "properties": {}}
),
]
result = is_function_tool_call(tool_call, tools)
assert result is False
def test_is_function_tool_call_false_wrong_type(self):
tool_call = OpenAIChatCompletionToolCall(
index=0,
id="call_123",
function=OpenAIChatCompletionToolCallFunction(
name="web_search",
arguments="{}",
),
)
tools = [
OpenAIResponseInputToolWebSearch(type="web_search"),
]
result = is_function_tool_call(tool_call, tools)
assert result is False


@@ -0,0 +1,753 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Test suite for the reference implementation of the Batches API.
The tests are categorized and outlined below; keep this list updated as tests are added or changed:
- Batch creation with various parameters and validation:
* test_create_and_retrieve_batch_success (positive)
* test_create_batch_without_metadata (positive)
* test_create_batch_completion_window (negative)
* test_create_batch_invalid_endpoints (negative)
* test_create_batch_invalid_metadata (negative)
- Batch retrieval and error handling for non-existent batches:
* test_retrieve_batch_not_found (negative)
- Batch cancellation with proper status transitions:
* test_cancel_batch_success (positive)
* test_cancel_batch_invalid_statuses (negative)
* test_cancel_batch_not_found (negative)
- Batch listing with pagination and filtering:
* test_list_batches_empty (positive)
* test_list_batches_single_batch (positive)
* test_list_batches_multiple_batches (positive)
* test_list_batches_with_limit (positive)
* test_list_batches_with_pagination (positive)
* test_list_batches_invalid_after (negative)
- Data persistence in the underlying key-value store:
* test_kvstore_persistence (positive)
- Batch processing concurrency control:
* test_max_concurrent_batches (positive)
- Input validation testing (direct _validate_input method tests):
* test_validate_input_file_not_found (negative)
* test_validate_input_file_exists_empty_content (positive)
* test_validate_input_file_mixed_valid_invalid_json (mixed)
* test_validate_input_invalid_model (negative)
* test_validate_input_url_mismatch (negative)
* test_validate_input_multiple_errors_per_request (negative)
* test_validate_input_invalid_request_format (negative)
* test_validate_input_missing_parameters (parametrized negative - validation of missing custom_id, method, url, body, model, messages)
* test_validate_input_invalid_parameter_types (parametrized negative - type validation for custom_id, url, method, body, model, messages)
The tests use temporary SQLite databases for isolation and mock external
dependencies like inference, files, and models APIs.
"""
import json
import tempfile
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.batches import BatchObject
from llama_stack.apis.common.errors import ConflictError, ResourceNotFoundError
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
class TestReferenceBatchesImpl:
"""Test the reference implementation of the Batches API."""
@pytest.fixture
async def provider(self):
"""Create a test provider instance with temporary database."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_batches.db"
kvstore_config = SqliteKVStoreConfig(db_path=str(db_path))
config = ReferenceBatchesImplConfig(kvstore=kvstore_config)
# Create kvstore and mock APIs
from unittest.mock import AsyncMock
from llama_stack.providers.utils.kvstore import kvstore_impl
kvstore = await kvstore_impl(config.kvstore)
mock_inference = AsyncMock()
mock_files = AsyncMock()
mock_models = AsyncMock()
provider = ReferenceBatchesImpl(config, mock_inference, mock_files, mock_models, kvstore)
await provider.initialize()
# unit tests should not require background processing
provider.process_batches = False
yield provider
await provider.shutdown()
@pytest.fixture
def sample_batch_data(self):
"""Sample batch data for testing."""
return {
"input_file_id": "file_abc123",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"metadata": {"test": "true", "priority": "high"},
}
def _validate_batch_type(self, batch, expected_metadata=None):
"""
Helper function to validate batch object structure and field types.
Note: This validates the direct BatchObject from the provider, not the
client library response which has a different structure.
Args:
batch: The BatchObject instance to validate.
expected_metadata: Optional expected metadata dictionary to validate against.
"""
assert isinstance(batch.id, str)
assert isinstance(batch.completion_window, str)
assert isinstance(batch.created_at, int)
assert isinstance(batch.endpoint, str)
assert isinstance(batch.input_file_id, str)
assert batch.object == "batch"
assert batch.status in [
"validating",
"failed",
"in_progress",
"finalizing",
"completed",
"expired",
"cancelling",
"cancelled",
]
if expected_metadata is not None:
assert batch.metadata == expected_metadata
timestamp_fields = [
"cancelled_at",
"cancelling_at",
"completed_at",
"expired_at",
"expires_at",
"failed_at",
"finalizing_at",
"in_progress_at",
]
for field in timestamp_fields:
field_value = getattr(batch, field, None)
if field_value is not None:
assert isinstance(field_value, int), f"{field} should be int or None, got {type(field_value)}"
file_id_fields = ["error_file_id", "output_file_id"]
for field in file_id_fields:
field_value = getattr(batch, field, None)
if field_value is not None:
assert isinstance(field_value, str), f"{field} should be str or None, got {type(field_value)}"
if hasattr(batch, "request_counts") and batch.request_counts is not None:
assert isinstance(batch.request_counts.completed, int), (
f"request_counts.completed should be int, got {type(batch.request_counts.completed)}"
)
assert isinstance(batch.request_counts.failed, int), (
f"request_counts.failed should be int, got {type(batch.request_counts.failed)}"
)
assert isinstance(batch.request_counts.total, int), (
f"request_counts.total should be int, got {type(batch.request_counts.total)}"
)
if hasattr(batch, "errors") and batch.errors is not None:
assert isinstance(batch.errors, dict), f"errors should be object or dict, got {type(batch.errors)}"
if hasattr(batch.errors, "data") and batch.errors.data is not None:
assert isinstance(batch.errors.data, list), (
f"errors.data should be list or None, got {type(batch.errors.data)}"
)
for i, error_item in enumerate(batch.errors.data):
assert isinstance(error_item, dict), (
f"errors.data[{i}] should be object or dict, got {type(error_item)}"
)
if hasattr(error_item, "code") and error_item.code is not None:
assert isinstance(error_item.code, str), (
f"errors.data[{i}].code should be str or None, got {type(error_item.code)}"
)
if hasattr(error_item, "line") and error_item.line is not None:
assert isinstance(error_item.line, int), (
f"errors.data[{i}].line should be int or None, got {type(error_item.line)}"
)
if hasattr(error_item, "message") and error_item.message is not None:
assert isinstance(error_item.message, str), (
f"errors.data[{i}].message should be str or None, got {type(error_item.message)}"
)
if hasattr(error_item, "param") and error_item.param is not None:
assert isinstance(error_item.param, str), (
f"errors.data[{i}].param should be str or None, got {type(error_item.param)}"
)
if hasattr(batch.errors, "object") and batch.errors.object is not None:
assert isinstance(batch.errors.object, str), (
f"errors.object should be str or None, got {type(batch.errors.object)}"
)
assert batch.errors.object == "list", f"errors.object should be 'list', got {batch.errors.object}"
async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data):
"""Test successful batch creation and retrieval."""
created_batch = await provider.create_batch(**sample_batch_data)
self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"])
assert created_batch.id.startswith("batch_")
assert len(created_batch.id) > 13
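# "batch_" prefix (6 chars) plus what is presumably a generated unique suffix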
assert created_batch.object == "batch"
assert created_batch.endpoint == sample_batch_data["endpoint"]
assert created_batch.input_file_id == sample_batch_data["input_file_id"]
assert created_batch.completion_window == sample_batch_data["completion_window"]
assert created_batch.status == "validating"
assert created_batch.metadata == sample_batch_data["metadata"]
assert isinstance(created_batch.created_at, int)
assert created_batch.created_at > 0
retrieved_batch = await provider.retrieve_batch(created_batch.id)
self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"])
assert retrieved_batch.id == created_batch.id
assert retrieved_batch.input_file_id == created_batch.input_file_id
assert retrieved_batch.endpoint == created_batch.endpoint
assert retrieved_batch.status == created_batch.status
assert retrieved_batch.metadata == created_batch.metadata
async def test_create_batch_without_metadata(self, provider):
"""Test batch creation without optional metadata."""
batch = await provider.create_batch(
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h"
)
assert batch.metadata is None
async def test_create_batch_completion_window(self, provider):
"""Test batch creation with invalid completion window."""
with pytest.raises(ValueError, match="Invalid completion_window"):
await provider.create_batch(
input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now"
)
@pytest.mark.parametrize(
"endpoint",
[
"/v1/embeddings",
"/v1/completions",
"/v1/invalid/endpoint",
"",
],
)
async def test_create_batch_invalid_endpoints(self, provider, endpoint):
"""Test batch creation with various invalid endpoints."""
with pytest.raises(ValueError, match="Invalid endpoint"):
await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h")
async def test_create_batch_invalid_metadata(self, provider):
"""Test that batch creation fails with invalid metadata."""
with pytest.raises(ValueError, match="should be a valid string"):
await provider.create_batch(
input_file_id="file_123",
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={123: "invalid_key"}, # Non-string key
)
with pytest.raises(ValueError, match="should be a valid string"):
await provider.create_batch(
input_file_id="file_123",
endpoint="/v1/chat/completions",
completion_window="24h",
metadata={"valid_key": 456}, # Non-string value
)
async def test_retrieve_batch_not_found(self, provider):
"""Test error when retrieving non-existent batch."""
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
await provider.retrieve_batch("nonexistent_batch")
async def test_cancel_batch_success(self, provider, sample_batch_data):
"""Test successful batch cancellation."""
created_batch = await provider.create_batch(**sample_batch_data)
assert created_batch.status == "validating"
cancelled_batch = await provider.cancel_batch(created_batch.id)
assert cancelled_batch.id == created_batch.id
assert cancelled_batch.status in ["cancelling", "cancelled"]
assert isinstance(cancelled_batch.cancelling_at, int)
assert cancelled_batch.cancelling_at >= created_batch.created_at
@pytest.mark.parametrize("status", ["failed", "expired", "completed"])
async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status):
"""Test error when cancelling batch in final states."""
provider.process_batches = False
created_batch = await provider.create_batch(**sample_batch_data)
# directly update status in kvstore
await provider._update_batch(created_batch.id, status=status)
with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"):
await provider.cancel_batch(created_batch.id)
async def test_cancel_batch_not_found(self, provider):
"""Test error when cancelling non-existent batch."""
with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"):
await provider.cancel_batch("nonexistent_batch")
async def test_list_batches_empty(self, provider):
"""Test listing batches when none exist."""
response = await provider.list_batches()
assert response.object == "list"
assert response.data == []
assert response.first_id is None
assert response.last_id is None
assert response.has_more is False
async def test_list_batches_single_batch(self, provider, sample_batch_data):
"""Test listing batches with single batch."""
created_batch = await provider.create_batch(**sample_batch_data)
response = await provider.list_batches()
assert len(response.data) == 1
self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"])
assert response.data[0].id == created_batch.id
assert response.first_id == created_batch.id
assert response.last_id == created_batch.id
assert response.has_more is False
async def test_list_batches_multiple_batches(self, provider):
"""Test listing multiple batches."""
batches = [
await provider.create_batch(
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
)
for i in range(3)
]
response = await provider.list_batches()
assert len(response.data) == 3
batch_ids = {batch.id for batch in response.data}
expected_ids = {batch.id for batch in batches}
assert batch_ids == expected_ids
assert response.has_more is False
assert response.first_id in expected_ids
assert response.last_id in expected_ids
async def test_list_batches_with_limit(self, provider):
"""Test listing batches with limit parameter."""
batches = [
await provider.create_batch(
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
)
for i in range(3)
]
response = await provider.list_batches(limit=2)
assert len(response.data) == 2
assert response.has_more is True
assert response.first_id == response.data[0].id
assert response.last_id == response.data[1].id
batch_ids = {batch.id for batch in response.data}
expected_ids = {batch.id for batch in batches}
assert batch_ids.issubset(expected_ids)
async def test_list_batches_with_pagination(self, provider):
"""Test listing batches with pagination using 'after' parameter."""
for i in range(3):
await provider.create_batch(
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h"
)
# Get first page
first_page = await provider.list_batches(limit=1)
assert len(first_page.data) == 1
assert first_page.has_more is True
# Get second page using 'after'
second_page = await provider.list_batches(limit=1, after=first_page.data[0].id)
assert len(second_page.data) == 1
assert second_page.data[0].id != first_page.data[0].id
# Verify we got the next batch in order
all_batches = await provider.list_batches()
expected_second_batch_id = all_batches.data[1].id
assert second_page.data[0].id == expected_second_batch_id
async def test_list_batches_invalid_after(self, provider, sample_batch_data):
"""Test listing batches with invalid 'after' parameter."""
await provider.create_batch(**sample_batch_data)
response = await provider.list_batches(after="nonexistent_batch")
# Should return all batches (no filtering when 'after' batch not found)
assert len(response.data) == 1
async def test_kvstore_persistence(self, provider, sample_batch_data):
"""Test that batches are properly persisted in kvstore."""
batch = await provider.create_batch(**sample_batch_data)
stored_data = await provider.kvstore.get(f"batch:{batch.id}")
assert stored_data is not None
stored_batch_dict = json.loads(stored_data)
assert stored_batch_dict["id"] == batch.id
assert stored_batch_dict["input_file_id"] == sample_batch_data["input_file_id"]
async def test_validate_input_file_not_found(self, provider):
"""Test _validate_input when input file does not exist."""
provider.files_api.openai_retrieve_file = AsyncMock(side_effect=Exception("File not found"))
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="nonexistent_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "invalid_request"
assert errors[0].message == "Cannot find file nonexistent_file."
assert errors[0].param == "input_file_id"
assert errors[0].line is None
async def test_validate_input_file_exists_empty_content(self, provider):
"""Test _validate_input when file exists but is empty."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
mock_response.body = b""
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="empty_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 0
assert len(requests) == 0
async def test_validate_input_file_mixed_valid_invalid_json(self, provider):
"""Test _validate_input when file contains valid and invalid JSON lines."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
# Line 1: valid JSON with proper body args, Line 2: invalid JSON
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}\n{invalid json'
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="mixed_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
# Should have 1 JSON parsing error from line 2, and 1 valid request from line 1
assert len(errors) == 1
assert len(requests) == 1
assert errors[0].code == "invalid_json_line"
assert errors[0].line == 2
assert errors[0].message == "This line is not parseable as valid JSON."
assert requests[0].custom_id == "req-1"
assert requests[0].method == "POST"
assert requests[0].url == "/v1/chat/completions"
assert requests[0].body["model"] == "test-model"
assert requests[0].body["messages"] == [{"role": "user", "content": "Hello"}]
async def test_validate_input_invalid_model(self, provider):
"""Test _validate_input when file contains request with non-existent model."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "nonexistent-model", "messages": [{"role": "user", "content": "Hello"}]}}'
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
provider.models_api.get_model = AsyncMock(side_effect=Exception("Model not found"))
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="invalid_model_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "model_not_found"
assert errors[0].line == 1
assert errors[0].message == "Model 'nonexistent-model' does not exist or is not supported"
assert errors[0].param == "body.model"
@pytest.mark.parametrize(
"param_name,param_path,error_code,error_message",
[
("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
("method", "method", "missing_required_parameter", "Missing required parameter: method"),
("url", "url", "missing_required_parameter", "Missing required parameter: url"),
("body", "body", "missing_required_parameter", "Missing required parameter: body"),
("model", "body.model", "invalid_request", "Model parameter is required"),
("messages", "body.messages", "invalid_request", "Messages parameter is required"),
],
)
async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
"""Test _validate_input when file contains request with missing required parameters."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
base_request = {
"custom_id": "req-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
}
# Remove the specific parameter being tested
if "." in param_path:
top_level, nested_param = param_path.split(".", 1)
del base_request[top_level][nested_param]
else:
del base_request[param_name]
mock_response.body = json.dumps(base_request).encode()
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id=f"missing_{param_name}_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == error_code
assert errors[0].line == 1
assert errors[0].message == error_message
assert errors[0].param == param_path
async def test_validate_input_url_mismatch(self, provider):
"""Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
mock_response.body = b'{"custom_id": "req-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]}}'
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions", # This doesn't match the URL in the request
input_file_id="url_mismatch_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "invalid_url"
assert errors[0].line == 1
assert errors[0].message == "URL provided for this request does not match the batch endpoint"
assert errors[0].param == "url"
async def test_validate_input_multiple_errors_per_request(self, provider):
"""Test _validate_input when a single request has multiple validation errors."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
# Request missing custom_id, has invalid URL, and missing model in body
mock_response.body = (
b'{"method": "POST", "url": "/v1/embeddings", "body": {"messages": [{"role": "user", "content": "Hello"}]}}'
)
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions", # Doesn't match /v1/embeddings in request
input_file_id="multiple_errors_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) >= 2 # At least missing custom_id and URL mismatch
assert len(requests) == 0
for error in errors:
assert error.line == 1
error_codes = {error.code for error in errors}
assert "missing_required_parameter" in error_codes # missing custom_id
assert "invalid_url" in error_codes # URL mismatch
async def test_validate_input_invalid_request_format(self, provider):
"""Test _validate_input when file contains non-object JSON (array, string, number)."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
mock_response.body = b'["not", "a", "request", "object"]'
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id="invalid_format_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "invalid_request"
assert errors[0].line == 1
assert errors[0].message == "Each line must be a JSON dictionary object"
@pytest.mark.parametrize(
"param_name,param_path,invalid_value,error_message",
[
("custom_id", "custom_id", 12345, "Custom_id must be a string"),
("url", "url", 123, "URL must be a string"),
("method", "method", ["POST"], "Method must be a string"),
("body", "body", ["not", "valid"], "Body must be a JSON dictionary object"),
("model", "body.model", 123, "Model must be a string"),
("messages", "body.messages", "invalid messages format", "Messages must be an array"),
],
)
async def test_validate_input_invalid_parameter_types(
self, provider, param_name, param_path, invalid_value, error_message
):
"""Test _validate_input when file contains request with parameters that have invalid types."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
base_request = {
"custom_id": "req-1",
"method": "POST",
"url": "/v1/chat/completions",
"body": {"model": "test-model", "messages": [{"role": "user", "content": "Hello"}]},
}
# Override the specific parameter with invalid value
if "." in param_path:
top_level, nested_param = param_path.split(".", 1)
base_request[top_level][nested_param] = invalid_value
else:
base_request[param_name] = invalid_value
mock_response.body = json.dumps(base_request).encode()
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/chat/completions",
input_file_id=f"invalid_{param_name}_type_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == "invalid_request"
assert errors[0].line == 1
assert errors[0].message == error_message
assert errors[0].param == param_path
async def test_max_concurrent_batches(self, provider):
"""Test max_concurrent_batches configuration and concurrency control."""
import asyncio
provider._batch_semaphore = asyncio.Semaphore(2)
provider.process_batches = True # enable because we're testing background processing
active_batches = 0
async def add_and_wait(batch_id: str):
nonlocal active_batches
active_batches += 1
await asyncio.sleep(float("inf"))
# the first thing _process_batch does is acquire the semaphore and then call _process_batch_impl,
# so we can replace _process_batch_impl with our mock to control concurrency
provider._process_batch_impl = add_and_wait
for _ in range(3):
await provider.create_batch(
input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h"
)
await asyncio.sleep(0.042) # let tasks start
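# with a semaphore of 2, only two mocked _process_batch_impl calls should have started;
# the third batch's task should still be blocked waiting to acquire the semaphore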
assert active_batches == 2, f"Expected 2 active batches, got {active_batches}"


@@ -24,6 +24,7 @@ from llama_stack.apis.inference (
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
convert_message_to_openai_dict_new,
openai_messages_to_messages,
)
@@ -182,3 +183,42 @@ def test_user_message_accepts_images():
assert len(msg.content) == 2
assert msg.content[0].text == "Describe this image:"
assert msg.content[1].image_url.url == "http://example.com/image.jpg"
async def test_convert_message_to_openai_dict_new_user_message():
"""Test convert_message_to_openai_dict_new with UserMessage."""
message = UserMessage(content="Hello, world!", role="user")
result = await convert_message_to_openai_dict_new(message)
assert result["role"] == "user"
assert result["content"] == "Hello, world!"
async def test_convert_message_to_openai_dict_new_completion_message_with_tool_calls():
"""Test convert_message_to_openai_dict_new with CompletionMessage containing tool calls."""
message = CompletionMessage(
content="I'll help you find the weather.",
tool_calls=[
ToolCall(
call_id="call_123",
tool_name="get_weather",
arguments={"city": "Sligo"},
arguments_json='{"city": "Sligo"}',
)
],
stop_reason=StopReason.end_of_turn,
)
result = await convert_message_to_openai_dict_new(message)
# This would have failed with "Cannot instantiate typing.Union" before the fix
assert result["role"] == "assistant"
assert result["content"] == "I'll help you find the weather."
assert "tool_calls" in result
assert result["tool_calls"] is not None
assert len(result["tool_calls"]) == 1
tool_call = result["tool_calls"][0]
assert tool_call.id == "call_123"
assert tool_call.type == "function"
assert tool_call.function.name == "get_weather"
assert tool_call.function.arguments == '{"city": "Sligo"}'