Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-03 09:53:45 +00:00)

Commit 3770963130: Merge branch 'main' into routeur
255 changed files with 18366 additions and 1909 deletions
@@ -17,6 +17,7 @@ from openai.types.chat.chat_completion_chunk import (
from llama_stack.core.access_control.access_control import default_policy
|
||||
from llama_stack.core.storage.datatypes import ResponsesStoreReference, SqliteSqlStoreConfig
|
||||
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
@@ -24,7 +25,13 @@ from llama_stack.providers.utils.responses.responses_store import (
ResponsesStore,
|
||||
_OpenAIResponseObjectWithInputAndMessages,
|
||||
)
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
from llama_stack_api import (
|
||||
OpenAIChatCompletionContentPartImageParam,
|
||||
OpenAIFile,
|
||||
OpenAIFileObject,
|
||||
OpenAISystemMessageParam,
|
||||
Prompt,
|
||||
)
|
||||
from llama_stack_api.agents import Order
|
||||
from llama_stack_api.inference import (
|
||||
OpenAIAssistantMessageParam,
@@ -38,6 +45,8 @@ from llama_stack_api.inference import (
)
|
||||
from llama_stack_api.openai_responses import (
|
||||
ListOpenAIResponseInputItem,
|
||||
OpenAIResponseInputMessageContentFile,
|
||||
OpenAIResponseInputMessageContentImage,
|
||||
OpenAIResponseInputMessageContentText,
|
||||
OpenAIResponseInputToolFunction,
|
||||
OpenAIResponseInputToolMCP,
@@ -47,6 +56,7 @@ from llama_stack_api.openai_responses import (
OpenAIResponseOutputMessageFunctionToolCall,
|
||||
OpenAIResponseOutputMessageMCPCall,
|
||||
OpenAIResponseOutputMessageWebSearchToolCall,
|
||||
OpenAIResponsePrompt,
|
||||
OpenAIResponseText,
|
||||
OpenAIResponseTextFormat,
|
||||
WebSearchToolTypes,
@@ -98,6 +108,19 @@ def mock_safety_api():
return safety_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_prompts_api():
|
||||
prompts_api = AsyncMock()
|
||||
return prompts_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_files_api():
|
||||
"""Mock files API for testing."""
|
||||
files_api = AsyncMock()
|
||||
return files_api
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def openai_responses_impl(
|
||||
mock_inference_api,
@@ -107,6 +130,8 @@ def openai_responses_impl(
mock_vector_io_api,
|
||||
mock_safety_api,
|
||||
mock_conversations_api,
|
||||
mock_prompts_api,
|
||||
mock_files_api,
|
||||
):
|
||||
return OpenAIResponsesImpl(
|
||||
inference_api=mock_inference_api,
@@ -116,6 +141,8 @@ def openai_responses_impl(
vector_io_api=mock_vector_io_api,
|
||||
safety_api=mock_safety_api,
|
||||
conversations_api=mock_conversations_api,
|
||||
prompts_api=mock_prompts_api,
|
||||
files_api=mock_files_api,
|
||||
)
@@ -499,7 +526,7 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
|
||||
|
||||
|
||||
-async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
+async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api, mock_files_api):
|
||||
"""Test creating an OpenAI response with multiple messages."""
|
||||
# Setup
|
||||
input_messages = [
@@ -710,7 +737,7 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
|
||||
|
||||
async def test_create_openai_response_with_instructions_and_multiple_messages(
|
||||
-openai_responses_impl, mock_inference_api
+openai_responses_impl, mock_inference_api, mock_files_api
|
||||
):
|
||||
# Setup
|
||||
input_messages = [
@@ -1242,3 +1269,489 @@ async def test_create_openai_response_with_output_types_as_input(
|
||||
assert stored_with_outputs.input == input_with_output_types
|
||||
assert len(stored_with_outputs.input) == 3
|
||||
|
||||
|
||||
async def test_create_openai_response_with_prompt(openai_responses_impl, mock_inference_api, mock_prompts_api):
|
||||
"""Test creating an OpenAI response with a prompt."""
|
||||
input_text = "What is the capital of Ireland?"
|
||||
model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
prompt = Prompt(
|
||||
prompt="You are a helpful {{ area_name }} assistant at {{ company_name }}. Always provide accurate information.",
|
||||
prompt_id=prompt_id,
|
||||
version=1,
|
||||
variables=["area_name", "company_name"],
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
openai_response_prompt = OpenAIResponsePrompt(
|
||||
id=prompt_id,
|
||||
version="1",
|
||||
variables={
|
||||
"area_name": OpenAIResponseInputMessageContentText(text="geography"),
|
||||
"company_name": OpenAIResponseInputMessageContentText(text="Dummy Company"),
|
||||
},
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = prompt
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream()
|
||||
|
||||
result = await openai_responses_impl.create_openai_response(
|
||||
input=input_text,
|
||||
model=model,
|
||||
prompt=openai_response_prompt,
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.assert_called_with(prompt_id, 1)
|
||||
mock_inference_api.openai_chat_completion.assert_called()
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args
|
||||
sent_messages = call_args.args[0].messages
|
||||
assert len(sent_messages) == 2
|
||||
|
||||
system_messages = [msg for msg in sent_messages if msg.role == "system"]
|
||||
assert len(system_messages) == 1
|
||||
assert (
|
||||
system_messages[0].content
|
||||
== "You are a helpful geography assistant at Dummy Company. Always provide accurate information."
|
||||
)
|
||||
|
||||
user_messages = [msg for msg in sent_messages if msg.role == "user"]
|
||||
assert len(user_messages) == 1
|
||||
assert user_messages[0].content == input_text
|
||||
|
||||
assert result.model == model
|
||||
assert result.status == "completed"
|
||||
assert isinstance(result.prompt, OpenAIResponsePrompt)
|
||||
assert result.prompt.id == prompt_id
|
||||
assert result.prompt.variables == openai_response_prompt.variables
|
||||
assert result.prompt.version == "1"
|
||||
|
||||
|
||||
async def test_prepend_prompt_successful_without_variables(openai_responses_impl, mock_prompts_api, mock_inference_api):
|
||||
"""Test prepend_prompt function without variables."""
|
||||
input_text = "What is the capital of Ireland?"
|
||||
model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
prompt = Prompt(
|
||||
prompt="You are a helpful assistant. Always provide accurate information.",
|
||||
prompt_id=prompt_id,
|
||||
version=1,
|
||||
variables=[],
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
openai_response_prompt = OpenAIResponsePrompt(id=prompt_id, version="1")
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = prompt
|
||||
mock_inference_api.openai_chat_completion.return_value = fake_stream()
|
||||
|
||||
await openai_responses_impl.create_openai_response(
|
||||
input=input_text,
|
||||
model=model,
|
||||
prompt=openai_response_prompt,
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.assert_called_with(prompt_id, 1)
|
||||
mock_inference_api.openai_chat_completion.assert_called()
|
||||
call_args = mock_inference_api.openai_chat_completion.call_args
|
||||
sent_messages = call_args.args[0].messages
|
||||
assert len(sent_messages) == 2
|
||||
system_messages = [msg for msg in sent_messages if msg.role == "system"]
|
||||
assert system_messages[0].content == "You are a helpful assistant. Always provide accurate information."
|
||||
|
||||
|
||||
async def test_prepend_prompt_invalid_variable(openai_responses_impl, mock_prompts_api):
|
||||
"""Test error handling in prepend_prompt function when prompt parameters contain invalid variables."""
|
||||
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
prompt = Prompt(
|
||||
prompt="You are a {{ role }} assistant.",
|
||||
prompt_id=prompt_id,
|
||||
version=1,
|
||||
variables=["role"], # Only "role" is valid
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
openai_response_prompt = OpenAIResponsePrompt(
|
||||
id=prompt_id,
|
||||
version="1",
|
||||
variables={
|
||||
"role": OpenAIResponseInputMessageContentText(text="helpful"),
|
||||
"company": OpenAIResponseInputMessageContentText(
|
||||
text="Dummy Company"
|
||||
), # company is not in prompt.variables
|
||||
},
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = prompt
|
||||
|
||||
# Initial messages
|
||||
messages = [OpenAIUserMessageParam(content="Test prompt")]
|
||||
|
||||
# Execute - should raise ValueError for invalid variable
|
||||
with pytest.raises(ValueError, match="Variable company not found in prompt"):
|
||||
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
|
||||
|
||||
# Verify
|
||||
mock_prompts_api.get_prompt.assert_called_once_with(prompt_id, 1)
|
||||
|
||||
|
||||
async def test_prepend_prompt_not_found(openai_responses_impl, mock_prompts_api):
|
||||
"""Test prepend_prompt function when prompt is not found."""
|
||||
prompt_id = "pmpt_nonexistent"
|
||||
openai_response_prompt = OpenAIResponsePrompt(id=prompt_id, version="1")
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = None # Prompt not found
|
||||
|
||||
# Initial messages
|
||||
messages = [OpenAIUserMessageParam(content="Test prompt")]
|
||||
initial_length = len(messages)
|
||||
|
||||
# Execute
|
||||
result = await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
|
||||
|
||||
# Verify
|
||||
mock_prompts_api.get_prompt.assert_called_once_with(prompt_id, 1)
|
||||
|
||||
# Should return None when prompt not found
|
||||
assert result is None
|
||||
|
||||
# Messages should not be modified
|
||||
assert len(messages) == initial_length
|
||||
assert messages[0].content == "Test prompt"
|
||||
|
||||
|
||||
async def test_prepend_prompt_variable_substitution(openai_responses_impl, mock_prompts_api):
|
||||
"""Test complex variable substitution with multiple occurrences and special characters in prepend_prompt function."""
|
||||
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
|
||||
# Support all whitespace variations: {{name}}, {{ name }}, {{ name}}, {{name }}, etc.
|
||||
prompt = Prompt(
|
||||
prompt="Hello {{name}}! You are working at {{ company}}. Your role is {{role}} at {{company}}. Remember, {{ name }}, to be {{ tone }}.",
|
||||
prompt_id=prompt_id,
|
||||
version=1,
|
||||
variables=["name", "company", "role", "tone"],
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
openai_response_prompt = OpenAIResponsePrompt(
|
||||
id=prompt_id,
|
||||
version="1",
|
||||
variables={
|
||||
"name": OpenAIResponseInputMessageContentText(text="Alice"),
|
||||
"company": OpenAIResponseInputMessageContentText(text="Dummy Company"),
|
||||
"role": OpenAIResponseInputMessageContentText(text="AI Assistant"),
|
||||
"tone": OpenAIResponseInputMessageContentText(text="professional"),
|
||||
},
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = prompt
|
||||
|
||||
# Initial messages
|
||||
messages = [OpenAIUserMessageParam(content="Test")]
|
||||
|
||||
# Execute
|
||||
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
|
||||
|
||||
# Verify
|
||||
assert len(messages) == 2
|
||||
assert isinstance(messages[0], OpenAISystemMessageParam)
|
||||
expected_content = "Hello Alice! You are working at Dummy Company. Your role is AI Assistant at Dummy Company. Remember, Alice, to be professional."
|
||||
assert messages[0].content == expected_content
|
||||
|
||||
|
||||
async def test_prepend_prompt_with_image_variable(openai_responses_impl, mock_prompts_api, mock_files_api):
|
||||
"""Test prepend_prompt with image variable - should create placeholder in system message and append image as separate user message."""
|
||||
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
prompt = Prompt(
|
||||
prompt="Analyze this {{product_image}} and describe what you see.",
|
||||
prompt_id=prompt_id,
|
||||
version=1,
|
||||
variables=["product_image"],
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
# Mock file content and file metadata
|
||||
mock_file_content = b"fake_image_data"
|
||||
mock_files_api.openai_retrieve_file_content.return_value = type("obj", (object,), {"body": mock_file_content})()
|
||||
mock_files_api.openai_retrieve_file.return_value = OpenAIFileObject(
|
||||
object="file",
|
||||
id="file-abc123",
|
||||
bytes=len(mock_file_content),
|
||||
created_at=1234567890,
|
||||
expires_at=1234567890,
|
||||
filename="product.jpg",
|
||||
purpose="assistants",
|
||||
)
|
||||
|
||||
openai_response_prompt = OpenAIResponsePrompt(
|
||||
id=prompt_id,
|
||||
version="1",
|
||||
variables={
|
||||
"product_image": OpenAIResponseInputMessageContentImage(
|
||||
file_id="file-abc123",
|
||||
detail="high",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = prompt
|
||||
|
||||
# Initial messages
|
||||
messages = [OpenAIUserMessageParam(content="What do you think?")]
|
||||
|
||||
# Execute
|
||||
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
|
||||
|
||||
assert len(messages) == 3
|
||||
|
||||
# Check system message has placeholder
|
||||
assert isinstance(messages[0], OpenAISystemMessageParam)
|
||||
assert messages[0].content == "Analyze this [Image: product_image] and describe what you see."
|
||||
|
||||
# Check original user message is still there
|
||||
assert isinstance(messages[1], OpenAIUserMessageParam)
|
||||
assert messages[1].content == "What do you think?"
|
||||
|
||||
# Check new user message with image is appended
|
||||
assert isinstance(messages[2], OpenAIUserMessageParam)
|
||||
assert isinstance(messages[2].content, list)
|
||||
assert len(messages[2].content) == 1
|
||||
|
||||
# Should be image with data URL
|
||||
assert isinstance(messages[2].content[0], OpenAIChatCompletionContentPartImageParam)
|
||||
assert messages[2].content[0].image_url.url.startswith("data:image/")
|
||||
assert messages[2].content[0].image_url.detail == "high"
|
||||
|
||||
|
||||
async def test_prepend_prompt_with_file_variable(openai_responses_impl, mock_prompts_api, mock_files_api):
|
||||
"""Test prepend_prompt with file variable - should create placeholder in system message and append file as separate user message."""
|
||||
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
prompt = Prompt(
|
||||
prompt="Review the document {{contract_file}} and summarize key points.",
|
||||
prompt_id=prompt_id,
|
||||
version=1,
|
||||
variables=["contract_file"],
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
# Mock file retrieval
|
||||
mock_file_content = b"fake_pdf_content"
|
||||
mock_files_api.openai_retrieve_file_content.return_value = type("obj", (object,), {"body": mock_file_content})()
|
||||
mock_files_api.openai_retrieve_file.return_value = OpenAIFileObject(
|
||||
object="file",
|
||||
id="file-contract-789",
|
||||
bytes=len(mock_file_content),
|
||||
created_at=1234567890,
|
||||
expires_at=1234567890,
|
||||
filename="contract.pdf",
|
||||
purpose="assistants",
|
||||
)
|
||||
|
||||
openai_response_prompt = OpenAIResponsePrompt(
|
||||
id=prompt_id,
|
||||
version="1",
|
||||
variables={
|
||||
"contract_file": OpenAIResponseInputMessageContentFile(
|
||||
file_id="file-contract-789",
|
||||
filename="contract.pdf",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = prompt
|
||||
|
||||
# Initial messages
|
||||
messages = [OpenAIUserMessageParam(content="Please review this.")]
|
||||
|
||||
# Execute
|
||||
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
|
||||
|
||||
assert len(messages) == 3
|
||||
|
||||
# Check system message has placeholder
|
||||
assert isinstance(messages[0], OpenAISystemMessageParam)
|
||||
assert messages[0].content == "Review the document [File: contract_file] and summarize key points."
|
||||
|
||||
# Check original user message is still there
|
||||
assert isinstance(messages[1], OpenAIUserMessageParam)
|
||||
assert messages[1].content == "Please review this."
|
||||
|
||||
# Check new user message with file is appended
|
||||
assert isinstance(messages[2], OpenAIUserMessageParam)
|
||||
assert isinstance(messages[2].content, list)
|
||||
assert len(messages[2].content) == 1
|
||||
|
||||
# First part should be file with data URL
|
||||
assert isinstance(messages[2].content[0], OpenAIFile)
|
||||
assert messages[2].content[0].file.file_data.startswith("data:application/pdf;base64,")
|
||||
assert messages[2].content[0].file.filename == "contract.pdf"
|
||||
assert messages[2].content[0].file.file_id is None
|
||||
|
||||
|
||||
async def test_prepend_prompt_with_mixed_variables(openai_responses_impl, mock_prompts_api, mock_files_api):
|
||||
"""Test prepend_prompt with text, image, and file variables mixed together."""
|
||||
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
prompt = Prompt(
|
||||
prompt="Hello {{name}}! Analyze {{photo}} and review {{document}}. Provide insights for {{company}}.",
|
||||
prompt_id=prompt_id,
|
||||
version=1,
|
||||
variables=["name", "photo", "document", "company"],
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
# Mock file retrieval for image and file
|
||||
mock_image_content = b"fake_image_data"
|
||||
mock_file_content = b"fake_doc_content"
|
||||
|
||||
async def mock_retrieve_file_content(file_id):
|
||||
if file_id == "file-photo-123":
|
||||
return type("obj", (object,), {"body": mock_image_content})()
|
||||
elif file_id == "file-doc-456":
|
||||
return type("obj", (object,), {"body": mock_file_content})()
|
||||
|
||||
mock_files_api.openai_retrieve_file_content.side_effect = mock_retrieve_file_content
|
||||
|
||||
def mock_retrieve_file(file_id):
|
||||
if file_id == "file-photo-123":
|
||||
return OpenAIFileObject(
|
||||
object="file",
|
||||
id="file-photo-123",
|
||||
bytes=len(mock_image_content),
|
||||
created_at=1234567890,
|
||||
expires_at=1234567890,
|
||||
filename="photo.jpg",
|
||||
purpose="assistants",
|
||||
)
|
||||
elif file_id == "file-doc-456":
|
||||
return OpenAIFileObject(
|
||||
object="file",
|
||||
id="file-doc-456",
|
||||
bytes=len(mock_file_content),
|
||||
created_at=1234567890,
|
||||
expires_at=1234567890,
|
||||
filename="doc.pdf",
|
||||
purpose="assistants",
|
||||
)
|
||||
|
||||
mock_files_api.openai_retrieve_file.side_effect = mock_retrieve_file
|
||||
|
||||
openai_response_prompt = OpenAIResponsePrompt(
|
||||
id=prompt_id,
|
||||
version="1",
|
||||
variables={
|
||||
"name": OpenAIResponseInputMessageContentText(text="Alice"),
|
||||
"photo": OpenAIResponseInputMessageContentImage(file_id="file-photo-123", detail="auto"),
|
||||
"document": OpenAIResponseInputMessageContentFile(file_id="file-doc-456", filename="doc.pdf"),
|
||||
"company": OpenAIResponseInputMessageContentText(text="Acme Corp"),
|
||||
},
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = prompt
|
||||
|
||||
# Initial messages
|
||||
messages = [OpenAIUserMessageParam(content="Here's my question.")]
|
||||
|
||||
# Execute
|
||||
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
|
||||
|
||||
assert len(messages) == 3
|
||||
|
||||
# Check system message has text and placeholders
|
||||
assert isinstance(messages[0], OpenAISystemMessageParam)
|
||||
expected_system = "Hello Alice! Analyze [Image: photo] and review [File: document]. Provide insights for Acme Corp."
|
||||
assert messages[0].content == expected_system
|
||||
|
||||
# Check original user message is still there
|
||||
assert isinstance(messages[1], OpenAIUserMessageParam)
|
||||
assert messages[1].content == "Here's my question."
|
||||
|
||||
# Check new user message with media is appended (2 media items)
|
||||
assert isinstance(messages[2], OpenAIUserMessageParam)
|
||||
assert isinstance(messages[2].content, list)
|
||||
assert len(messages[2].content) == 2
|
||||
|
||||
# First part should be image with data URL
|
||||
assert isinstance(messages[2].content[0], OpenAIChatCompletionContentPartImageParam)
|
||||
assert messages[2].content[0].image_url.url.startswith("data:image/")
|
||||
|
||||
# Second part should be file with data URL
|
||||
assert isinstance(messages[2].content[1], OpenAIFile)
|
||||
assert messages[2].content[1].file.file_data.startswith("data:application/pdf;base64,")
|
||||
assert messages[2].content[1].file.filename == "doc.pdf"
|
||||
assert messages[2].content[1].file.file_id is None
|
||||
|
||||
|
||||
async def test_prepend_prompt_with_image_using_image_url(openai_responses_impl, mock_prompts_api):
|
||||
"""Test prepend_prompt with image variable using image_url instead of file_id."""
|
||||
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
prompt = Prompt(
|
||||
prompt="Describe {{screenshot}}.",
|
||||
prompt_id=prompt_id,
|
||||
version=1,
|
||||
variables=["screenshot"],
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
openai_response_prompt = OpenAIResponsePrompt(
|
||||
id=prompt_id,
|
||||
version="1",
|
||||
variables={
|
||||
"screenshot": OpenAIResponseInputMessageContentImage(
|
||||
image_url="https://example.com/screenshot.png",
|
||||
detail="low",
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = prompt
|
||||
|
||||
# Initial messages
|
||||
messages = [OpenAIUserMessageParam(content="What is this?")]
|
||||
|
||||
# Execute
|
||||
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
|
||||
|
||||
assert len(messages) == 3
|
||||
|
||||
# Check system message has placeholder
|
||||
assert isinstance(messages[0], OpenAISystemMessageParam)
|
||||
assert messages[0].content == "Describe [Image: screenshot]."
|
||||
|
||||
# Check original user message is still there
|
||||
assert isinstance(messages[1], OpenAIUserMessageParam)
|
||||
assert messages[1].content == "What is this?"
|
||||
|
||||
# Check new user message with image is appended
|
||||
assert isinstance(messages[2], OpenAIUserMessageParam)
|
||||
assert isinstance(messages[2].content, list)
|
||||
|
||||
# Image should use the provided URL
|
||||
assert isinstance(messages[2].content[0], OpenAIChatCompletionContentPartImageParam)
|
||||
assert messages[2].content[0].image_url.url == "https://example.com/screenshot.png"
|
||||
assert messages[2].content[0].image_url.detail == "low"
|
||||
|
||||
|
||||
async def test_prepend_prompt_image_variable_missing_required_fields(openai_responses_impl, mock_prompts_api):
|
||||
"""Test prepend_prompt with image variable that has neither file_id nor image_url - should raise error."""
|
||||
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
|
||||
prompt = Prompt(
|
||||
prompt="Analyze {{bad_image}}.",
|
||||
prompt_id=prompt_id,
|
||||
version=1,
|
||||
variables=["bad_image"],
|
||||
is_default=True,
|
||||
)
|
||||
|
||||
# Create image content with neither file_id nor image_url
|
||||
openai_response_prompt = OpenAIResponsePrompt(
|
||||
id=prompt_id,
|
||||
version="1",
|
||||
variables={"bad_image": OpenAIResponseInputMessageContentImage()}, # No file_id or image_url
|
||||
)
|
||||
|
||||
mock_prompts_api.get_prompt.return_value = prompt
|
||||
messages = [OpenAIUserMessageParam(content="Test")]
|
||||
|
||||
# Execute - should raise ValueError
|
||||
with pytest.raises(ValueError, match="Image content must have either 'image_url' or 'file_id'"):
|
||||
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
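The tests above pin down the expected _prepend_prompt behavior: {{ var }} placeholders (with any inner whitespace) are replaced by text values, image and file variables become "[Image: name]" / "[File: name]" placeholders in the system message, and an undeclared variable raises a ValueError. A minimal sketch of that substitution logic, using hypothetical stand-in types rather than the real llama-stack classes, might look like this:

import re
from dataclasses import dataclass


@dataclass
class TextVar:          # stand-in for OpenAIResponseInputMessageContentText
    text: str


@dataclass
class ImageVar:         # stand-in for OpenAIResponseInputMessageContentImage
    file_id: str | None = None
    image_url: str | None = None


@dataclass
class FileVar:          # stand-in for OpenAIResponseInputMessageContentFile
    file_id: str | None = None
    filename: str | None = None


def render_prompt(template: str, declared: list[str], variables: dict) -> str:
    """Replace {{ var }} (any inner whitespace) with text, or with a media placeholder."""
    rendered = template
    for name, value in variables.items():
        if name not in declared:
            raise ValueError(f"Variable {name} not found in prompt {template!r}")
        if isinstance(value, TextVar):
            replacement = value.text
        elif isinstance(value, ImageVar):
            replacement = f"[Image: {name}]"  # media itself is appended as a separate user message
        elif isinstance(value, FileVar):
            replacement = f"[File: {name}]"
        else:
            raise ValueError(f"Unsupported variable type for {name}")
        rendered = re.sub(r"\{\{\s*" + re.escape(name) + r"\s*\}\}", replacement, rendered)
    return rendered


# Example mirroring test_prepend_prompt_variable_substitution above:
assert render_prompt(
    "Hello {{name}}! Be {{ tone }}.",
    ["name", "tone"],
    {"name": TextVar("Alice"), "tone": TextVar("professional")},
) == "Hello Alice! Be professional."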
@@ -39,6 +39,8 @@ def responses_impl_with_conversations(
mock_vector_io_api,
|
||||
mock_conversations_api,
|
||||
mock_safety_api,
|
||||
mock_prompts_api,
|
||||
mock_files_api,
|
||||
):
|
||||
"""Create OpenAIResponsesImpl instance with conversations API."""
|
||||
return OpenAIResponsesImpl(
@@ -49,6 +51,8 @@ def responses_impl_with_conversations(
vector_io_api=mock_vector_io_api,
|
||||
conversations_api=mock_conversations_api,
|
||||
safety_api=mock_safety_api,
|
||||
prompts_api=mock_prompts_api,
|
||||
files_api=mock_files_api,
|
||||
)
@@ -5,6 +5,8 @@
# the root directory of this source tree.
|
||||
|
||||
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
@@ -46,6 +48,12 @@ from llama_stack_api.openai_responses import (
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_files_api():
|
||||
"""Mock files API for testing."""
|
||||
return AsyncMock()
|
||||
|
||||
|
||||
class TestConvertChatChoiceToResponseMessage:
|
||||
async def test_convert_string_content(self):
|
||||
choice = OpenAIChoice(
@@ -78,17 +86,17 @@ class TestConvertChatChoiceToResponseMessage:
|
||||
|
||||
class TestConvertResponseContentToChatContent:
|
||||
-async def test_convert_string_content(self):
-result = await convert_response_content_to_chat_content("Simple string")
+async def test_convert_string_content(self, mock_files_api):
+result = await convert_response_content_to_chat_content("Simple string", mock_files_api)
|
||||
assert result == "Simple string"
|
||||
|
||||
-async def test_convert_text_content_parts(self):
+async def test_convert_text_content_parts(self, mock_files_api):
|
||||
content = [
|
||||
OpenAIResponseInputMessageContentText(text="First part"),
|
||||
OpenAIResponseOutputMessageContentOutputText(text="Second part"),
|
||||
]
|
||||
|
||||
-result = await convert_response_content_to_chat_content(content)
+result = await convert_response_content_to_chat_content(content, mock_files_api)
|
||||
|
||||
assert len(result) == 2
|
||||
assert isinstance(result[0], OpenAIChatCompletionContentPartTextParam)
@@ -96,10 +104,10 @@ class TestConvertResponseContentToChatContent:
assert isinstance(result[1], OpenAIChatCompletionContentPartTextParam)
|
||||
assert result[1].text == "Second part"
|
||||
|
||||
-async def test_convert_image_content(self):
+async def test_convert_image_content(self, mock_files_api):
|
||||
content = [OpenAIResponseInputMessageContentImage(image_url="https://example.com/image.jpg", detail="high")]
|
||||
|
||||
-result = await convert_response_content_to_chat_content(content)
+result = await convert_response_content_to_chat_content(content, mock_files_api)
|
||||
|
||||
assert len(result) == 1
|
||||
assert isinstance(result[0], OpenAIChatCompletionContentPartImageParam)
@@ -30,6 +30,8 @@ def mock_apis():
"vector_io_api": AsyncMock(),
|
||||
"conversations_api": AsyncMock(),
|
||||
"safety_api": AsyncMock(),
|
||||
"prompts_api": AsyncMock(),
|
||||
"files_api": AsyncMock(),
|
||||
}
@@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
"""Tests for making Safety API optional in meta-reference agents provider.
|
||||
|
||||
This test suite validates the changes introduced to fix issue #4165, which
|
||||
allows running the meta-reference agents provider without the Safety API.
|
||||
Safety API is now an optional dependency, and errors are raised at request time
|
||||
when guardrails are explicitly requested without Safety API configured.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.core.datatypes import Api
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
|
||||
from llama_stack.providers.inline.agents.meta_reference import get_provider_impl
|
||||
from llama_stack.providers.inline.agents.meta_reference.config import (
|
||||
AgentPersistenceConfig,
|
||||
MetaReferenceAgentsImplConfig,
|
||||
)
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
|
||||
run_guardrails,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_persistence_config():
|
||||
"""Create a mock persistence configuration."""
|
||||
return AgentPersistenceConfig(
|
||||
agent_state=KVStoreReference(
|
||||
backend="kv_default",
|
||||
namespace="agents",
|
||||
),
|
||||
responses=ResponsesStoreReference(
|
||||
backend="sql_default",
|
||||
table_name="responses",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_deps():
|
||||
"""Create mock dependencies for the agents provider."""
|
||||
# Create mock APIs
|
||||
inference_api = AsyncMock()
|
||||
vector_io_api = AsyncMock()
|
||||
tool_runtime_api = AsyncMock()
|
||||
tool_groups_api = AsyncMock()
|
||||
conversations_api = AsyncMock()
|
||||
prompts_api = AsyncMock()
|
||||
files_api = AsyncMock()
|
||||
|
||||
return {
|
||||
Api.inference: inference_api,
|
||||
Api.vector_io: vector_io_api,
|
||||
Api.tool_runtime: tool_runtime_api,
|
||||
Api.tool_groups: tool_groups_api,
|
||||
Api.conversations: conversations_api,
|
||||
Api.prompts: prompts_api,
|
||||
Api.files: files_api,
|
||||
}
|
||||
|
||||
|
||||
class TestProviderInitialization:
|
||||
"""Test provider initialization with different safety API configurations."""
|
||||
|
||||
async def test_initialization_with_safety_api_present(self, mock_persistence_config, mock_deps):
|
||||
"""Test successful initialization when Safety API is configured."""
|
||||
config = MetaReferenceAgentsImplConfig(persistence=mock_persistence_config)
|
||||
|
||||
# Add safety API to deps
|
||||
safety_api = AsyncMock()
|
||||
mock_deps[Api.safety] = safety_api
|
||||
|
||||
# Mock the initialize method to avoid actual initialization
|
||||
with patch(
|
||||
"llama_stack.providers.inline.agents.meta_reference.agents.MetaReferenceAgentsImpl.initialize",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
# Should not raise any exception
|
||||
provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
|
||||
assert provider is not None
|
||||
|
||||
async def test_initialization_without_safety_api(self, mock_persistence_config, mock_deps):
|
||||
"""Test successful initialization when Safety API is not configured."""
|
||||
config = MetaReferenceAgentsImplConfig(persistence=mock_persistence_config)
|
||||
|
||||
# Safety API is NOT in mock_deps - provider should still start
|
||||
# Mock the initialize method to avoid actual initialization
|
||||
with patch(
|
||||
"llama_stack.providers.inline.agents.meta_reference.agents.MetaReferenceAgentsImpl.initialize",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
# Should not raise any exception
|
||||
provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
|
||||
assert provider is not None
|
||||
assert provider.safety_api is None
|
||||
|
||||
|
||||
class TestGuardrailsFunctionality:
|
||||
"""Test run_guardrails function with optional safety API."""
|
||||
|
||||
async def test_run_guardrails_with_none_safety_api(self):
|
||||
"""Test that run_guardrails returns None when safety_api is None."""
|
||||
result = await run_guardrails(safety_api=None, messages="test message", guardrail_ids=["llama-guard"])
|
||||
assert result is None
|
||||
|
||||
async def test_run_guardrails_with_empty_messages(self):
|
||||
"""Test that run_guardrails returns None for empty messages."""
|
||||
# Test with None safety API
|
||||
result = await run_guardrails(safety_api=None, messages="", guardrail_ids=["llama-guard"])
|
||||
assert result is None
|
||||
|
||||
# Test with mock safety API
|
||||
mock_safety_api = AsyncMock()
|
||||
result = await run_guardrails(safety_api=mock_safety_api, messages="", guardrail_ids=["llama-guard"])
|
||||
assert result is None
|
||||
|
||||
async def test_run_guardrails_with_none_safety_api_ignores_guardrails(self):
|
||||
"""Test that guardrails are skipped when safety_api is None, even if guardrail_ids are provided."""
|
||||
# Should not raise exception, just return None
|
||||
result = await run_guardrails(
|
||||
safety_api=None,
|
||||
messages="potentially harmful content",
|
||||
guardrail_ids=["llama-guard", "content-filter"],
|
||||
)
|
||||
assert result is None
|
||||
|
||||
async def test_create_response_rejects_guardrails_without_safety_api(self, mock_persistence_config, mock_deps):
|
||||
"""Test that create_openai_response raises error when guardrails requested but Safety API unavailable."""
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
from llama_stack_api import ResponseGuardrailSpec
|
||||
|
||||
# Create OpenAIResponsesImpl with no safety API
|
||||
with patch("llama_stack.providers.inline.agents.meta_reference.responses.openai_responses.ResponsesStore"):
|
||||
impl = OpenAIResponsesImpl(
|
||||
inference_api=mock_deps[Api.inference],
|
||||
tool_groups_api=mock_deps[Api.tool_groups],
|
||||
tool_runtime_api=mock_deps[Api.tool_runtime],
|
||||
responses_store=MagicMock(),
|
||||
vector_io_api=mock_deps[Api.vector_io],
|
||||
safety_api=None, # No Safety API
|
||||
conversations_api=mock_deps[Api.conversations],
|
||||
prompts_api=mock_deps[Api.prompts],
|
||||
files_api=mock_deps[Api.files],
|
||||
)
|
||||
|
||||
# Test with string guardrail
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await impl.create_openai_response(
|
||||
input="test input",
|
||||
model="test-model",
|
||||
guardrails=["llama-guard"],
|
||||
)
|
||||
assert "Cannot process guardrails: Safety API is not configured" in str(exc_info.value)
|
||||
|
||||
# Test with ResponseGuardrailSpec
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
await impl.create_openai_response(
|
||||
input="test input",
|
||||
model="test-model",
|
||||
guardrails=[ResponseGuardrailSpec(type="llama-guard")],
|
||||
)
|
||||
assert "Cannot process guardrails: Safety API is not configured" in str(exc_info.value)
|
||||
|
||||
async def test_create_response_succeeds_without_guardrails_and_no_safety_api(
|
||||
self, mock_persistence_config, mock_deps
|
||||
):
|
||||
"""Test that create_openai_response works when no guardrails requested and Safety API unavailable."""
|
||||
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
|
||||
OpenAIResponsesImpl,
|
||||
)
|
||||
|
||||
# Create OpenAIResponsesImpl with no safety API
|
||||
with (
|
||||
patch("llama_stack.providers.inline.agents.meta_reference.responses.openai_responses.ResponsesStore"),
|
||||
patch.object(OpenAIResponsesImpl, "_create_streaming_response", new_callable=AsyncMock) as mock_stream,
|
||||
):
|
||||
# Mock the streaming response to return a simple async generator
|
||||
async def mock_generator():
|
||||
yield MagicMock()
|
||||
|
||||
mock_stream.return_value = mock_generator()
|
||||
|
||||
impl = OpenAIResponsesImpl(
|
||||
inference_api=mock_deps[Api.inference],
|
||||
tool_groups_api=mock_deps[Api.tool_groups],
|
||||
tool_runtime_api=mock_deps[Api.tool_runtime],
|
||||
responses_store=MagicMock(),
|
||||
vector_io_api=mock_deps[Api.vector_io],
|
||||
safety_api=None, # No Safety API
|
||||
conversations_api=mock_deps[Api.conversations],
|
||||
prompts_api=mock_deps[Api.prompts],
|
||||
files_api=mock_deps[Api.files],
|
||||
)
|
||||
|
||||
# Should not raise when no guardrails requested
|
||||
# Note: This will still fail later in execution due to mocking, but should pass the validation
|
||||
try:
|
||||
await impl.create_openai_response(
|
||||
input="test input",
|
||||
model="test-model",
|
||||
guardrails=None, # No guardrails
|
||||
)
|
||||
except Exception as e:
|
||||
# Ensure the error is NOT about missing Safety API
|
||||
assert "Cannot process guardrails: Safety API is not configured" not in str(e)
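For reference, the request-time behavior these tests encode can be summarized in a short, hedged sketch (assumed shape only, not the provider's actual code): guardrail evaluation is skipped when no Safety API is configured or the message is empty, and explicitly requested guardrails without a Safety API are rejected up front.

async def run_guardrails_sketch(safety_api, messages: str, guardrail_ids: list[str]) -> str | None:
    # Skip entirely when there is no Safety API or nothing to check.
    if safety_api is None or not messages:
        return None
    # ...otherwise dispatch to the configured guardrails and return any violation message.
    return None


def validate_guardrails_request(safety_api, guardrails: list | None) -> None:
    # Explicitly requested guardrails without a configured Safety API is a hard error.
    if guardrails and safety_api is None:
        raise ValueError("Cannot process guardrails: Safety API is not configured")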
@@ -13,9 +13,9 @@ from unittest.mock import AsyncMock
import pytest
|
||||
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.core.storage.kvstore import kvstore_impl, register_kvstore_backends
|
||||
from llama_stack.providers.inline.batches.reference.batches import ReferenceBatchesImpl
|
||||
from llama_stack.providers.inline.batches.reference.config import ReferenceBatchesImplConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl, register_kvstore_backends
|
||||
|
||||
|
||||
@pytest.fixture
@@ -9,8 +9,8 @@ import pytest
from moto import mock_aws
|
||||
|
||||
from llama_stack.core.storage.datatypes import SqliteSqlStoreConfig, SqlStoreReference
|
||||
from llama_stack.core.storage.sqlstore.sqlstore import register_sqlstore_backends
|
||||
from llama_stack.providers.remote.files.s3 import S3FilesImplConfig, get_adapter_impl
|
||||
from llama_stack.providers.utils.sqlstore.sqlstore import register_sqlstore_backends
|
||||
|
||||
|
||||
class MockUploadFile:
@@ -18,11 +18,11 @@ async def test_listing_hides_other_users_file(s3_provider, sample_text_file):
user_a = User("user-a", {"roles": ["team-a"]})
|
||||
user_b = User("user-b", {"roles": ["team-b"]})
|
||||
|
||||
-with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_a
|
||||
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
|
||||
|
||||
-with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_b
|
||||
listed = await s3_provider.openai_list_files()
|
||||
assert all(f.id != uploaded.id for f in listed.data)
@@ -41,11 +41,11 @@ async def test_cannot_access_other_user_file(s3_provider, sample_text_file, op):
user_a = User("user-a", {"roles": ["team-a"]})
|
||||
user_b = User("user-b", {"roles": ["team-b"]})
|
||||
|
||||
-with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_a
|
||||
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
|
||||
|
||||
-with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_b
|
||||
with pytest.raises(ResourceNotFoundError):
|
||||
await op(s3_provider, uploaded.id)
@@ -56,11 +56,11 @@ async def test_shared_role_allows_listing(s3_provider, sample_text_file):
user_a = User("user-a", {"roles": ["shared-role"]})
|
||||
user_b = User("user-b", {"roles": ["shared-role"]})
|
||||
|
||||
-with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_a
|
||||
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
|
||||
|
||||
-with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_b
|
||||
listed = await s3_provider.openai_list_files()
|
||||
assert any(f.id == uploaded.id for f in listed.data)
@@ -79,10 +79,10 @@ async def test_shared_role_allows_access(s3_provider, sample_text_file, op):
user_x = User("user-x", {"roles": ["shared-role"]})
|
||||
user_y = User("user-y", {"roles": ["shared-role"]})
|
||||
|
||||
-with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_x
|
||||
uploaded = await s3_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
|
||||
|
||||
-with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
+with patch("llama_stack.core.storage.sqlstore.authorized_sqlstore.get_authenticated_user") as mock_get_user:
|
||||
mock_get_user.return_value = user_y
|
||||
await op(s3_provider, uploaded.id)
@@ -40,8 +40,8 @@ def test_api_key_from_header_overrides_config():
"""Test API key from request header overrides config via client property"""
|
||||
config = BedrockConfig(api_key="config-key", region_name="us-east-1")
|
||||
adapter = BedrockInferenceAdapter(config=config)
|
||||
-adapter.provider_data_api_key_field = "aws_bedrock_api_key"
-adapter.get_request_provider_data = MagicMock(return_value=SimpleNamespace(aws_bedrock_api_key="header-key"))
+adapter.provider_data_api_key_field = "aws_bearer_token_bedrock"
+adapter.get_request_provider_data = MagicMock(return_value=SimpleNamespace(aws_bearer_token_bedrock="header-key"))
|
||||
|
||||
# The client property is where header override happens (in OpenAIMixin)
|
||||
assert adapter.client.api_key == "header-key"
@@ -9,7 +9,7 @@ from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
|
||||
def test_bedrock_config_defaults_no_env(monkeypatch):
|
||||
"""Test BedrockConfig defaults when env vars are not set"""
|
||||
-monkeypatch.delenv("AWS_BEDROCK_API_KEY", raising=False)
+monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False)
|
||||
monkeypatch.delenv("AWS_DEFAULT_REGION", raising=False)
|
||||
config = BedrockConfig()
|
||||
assert config.auth_credential is None
@@ -35,5 +35,5 @@ def test_bedrock_config_sample():
sample = BedrockConfig.sample_run_config()
|
||||
assert "api_key" in sample
|
||||
assert "region_name" in sample
|
||||
-assert sample["api_key"] == "${env.AWS_BEDROCK_API_KEY:=}"
+assert sample["api_key"] == "${env.AWS_BEARER_TOKEN_BEDROCK:=}"
|
||||
assert sample["region_name"] == "${env.AWS_DEFAULT_REGION:=us-east-2}"
@@ -120,7 +120,7 @@ from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInfere
VLLMInferenceAdapter,
|
||||
"llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
|
||||
{
-"url": "http://fake",
+"base_url": "http://fake",
|
||||
},
|
||||
),
|
||||
],
@@ -153,7 +153,7 @@ def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_valid
"""Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
|
||||
assumption that there is an OpenAI-compatible client object."""
|
||||
|
||||
-inference_adapter = adapter_cls(config=config_cls())
+inference_adapter = adapter_cls(config=config_cls(base_url="http://fake"))
|
||||
|
||||
inference_adapter.__provider_spec__ = MagicMock()
|
||||
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator
@@ -40,7 +40,7 @@ from llama_stack_api import (
|
||||
@pytest.fixture(scope="function")
|
||||
async def vllm_inference_adapter():
|
||||
-config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
+config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
|
||||
inference_adapter = VLLMInferenceAdapter(config=config)
|
||||
inference_adapter.model_store = AsyncMock()
|
||||
await inference_adapter.initialize()
@@ -204,7 +204,7 @@ async def test_vllm_completion_extra_body():
via extra_body to the underlying OpenAI client through the InferenceRouter.
|
||||
"""
|
||||
# Set up the vLLM adapter
|
||||
-config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
+config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
|
||||
vllm_adapter = VLLMInferenceAdapter(config=config)
|
||||
vllm_adapter.__provider_id__ = "vllm"
|
||||
await vllm_adapter.initialize()
@@ -277,7 +277,7 @@ async def test_vllm_chat_completion_extra_body():
via extra_body to the underlying OpenAI client through the InferenceRouter for chat completion.
|
||||
"""
|
||||
# Set up the vLLM adapter
|
||||
-config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
+config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
|
||||
vllm_adapter = VLLMInferenceAdapter(config=config)
|
||||
vllm_adapter.__provider_id__ = "vllm"
|
||||
await vllm_adapter.initialize()
@@ -146,7 +146,7 @@ async def test_hosted_model_not_in_endpoint_mapping():
|
||||
async def test_self_hosted_ignores_endpoint():
|
||||
adapter = create_adapter(
|
||||
-config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
+config=NVIDIAConfig(base_url="http://localhost:8000", api_key=None),
|
||||
rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted.
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
@@ -4,8 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import get_args, get_origin
|
||||
|
||||
import pytest
|
||||
-from pydantic import BaseModel
+from pydantic import BaseModel, HttpUrl
|
||||
|
||||
from llama_stack.core.distribution import get_provider_registry, providable_apis
|
||||
from llama_stack.core.utils.dynamic import instantiate_class_type
@@ -41,3 +43,55 @@ class TestProviderConfigurations:
|
||||
sample_config = config_type.sample_run_config(__distro_dir__="foobarbaz")
|
||||
assert isinstance(sample_config, dict), f"{config_class_name}.sample_run_config() did not return a dict"
|
||||
|
||||
def test_remote_inference_url_standardization(self):
|
||||
"""Verify all remote inference providers use standardized base_url configuration."""
|
||||
provider_registry = get_provider_registry()
|
||||
inference_providers = provider_registry.get("inference", {})
|
||||
|
||||
# Filter for remote providers only
|
||||
remote_providers = {k: v for k, v in inference_providers.items() if k.startswith("remote::")}
|
||||
|
||||
failures = []
|
||||
for provider_type, provider_spec in remote_providers.items():
|
||||
try:
|
||||
config_class_name = provider_spec.config_class
|
||||
config_type = instantiate_class_type(config_class_name)
|
||||
|
||||
# Check that config has base_url field (not url)
|
||||
if hasattr(config_type, "model_fields"):
|
||||
fields = config_type.model_fields
|
||||
|
||||
# Should NOT have 'url' field (old pattern)
|
||||
if "url" in fields:
|
||||
failures.append(
|
||||
f"{provider_type}: Uses deprecated 'url' field instead of 'base_url'. "
|
||||
f"Please rename to 'base_url' for consistency."
|
||||
)
|
||||
|
||||
# Should have 'base_url' field with HttpUrl | None type
|
||||
if "base_url" in fields:
|
||||
field_info = fields["base_url"]
|
||||
annotation = field_info.annotation
|
||||
|
||||
# Check if it's HttpUrl or HttpUrl | None
|
||||
# get_origin() returns Union for (X | Y), None for plain types
|
||||
# get_args() returns the types inside Union, e.g. (HttpUrl, NoneType)
|
||||
is_valid = False
|
||||
if get_origin(annotation) is not None: # It's a Union/Optional
|
||||
if HttpUrl in get_args(annotation):
|
||||
is_valid = True
|
||||
elif annotation == HttpUrl: # Plain HttpUrl without | None
|
||||
is_valid = True
|
||||
|
||||
if not is_valid:
|
||||
failures.append(
|
||||
f"{provider_type}: base_url field has incorrect type annotation. "
|
||||
f"Expected 'HttpUrl | None', got '{annotation}'"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
failures.append(f"{provider_type}: Error checking URL standardization: {str(e)}")
|
||||
|
||||
if failures:
|
||||
pytest.fail("URL standardization violations found:\n" + "\n".join(f" - {f}" for f in failures))
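As a quick aside, the get_origin/get_args check above works because Optional annotations expose their member types; a tiny standalone example (independent of llama-stack, and behavior for pydantic's HttpUrl alias may vary slightly across versions):

from typing import Optional, get_args, get_origin

from pydantic import HttpUrl

annotation = Optional[HttpUrl]             # i.e. HttpUrl | None
assert get_origin(annotation) is not None  # Union/Optional annotations have an origin
assert HttpUrl in get_args(annotation)     # and HttpUrl appears among the union members
assert get_origin(int) is None             # a plain type has no origin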
@@ -15,7 +15,14 @@ from pydantic import BaseModel, Field
from llama_stack.core.request_headers import request_provider_data_context
|
||||
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
from llama_stack_api import Model, ModelType, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
|
||||
from llama_stack_api import (
|
||||
Model,
|
||||
ModelType,
|
||||
OpenAIChatCompletionRequestWithExtraBody,
|
||||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAIEmbeddingsRequestWithExtraBody,
|
||||
OpenAIUserMessageParam,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIMixinImpl(OpenAIMixin):
@@ -834,3 +841,96 @@ class TestOpenAIMixinProviderDataApiKey:
error_message = str(exc_info.value)
|
||||
assert "test_api_key" in error_message
|
||||
assert "x-llamastack-provider-data" in error_message
|
||||
|
||||
|
||||
class TestOpenAIMixinAllowedModelsInference:
|
||||
"""Test cases for allowed_models enforcement during inference requests"""
|
||||
|
||||
async def test_inference_with_allowed_models(self, mixin, mock_client_context):
|
||||
"""Test that all inference methods succeed with allowed models"""
|
||||
mixin.config.allowed_models = ["gpt-4", "text-davinci-003", "text-embedding-ada-002"]
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
|
||||
mock_client.completions.create = AsyncMock(return_value=MagicMock())
|
||||
mock_embedding_response = MagicMock()
|
||||
mock_embedding_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])]
|
||||
mock_embedding_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)
|
||||
mock_client.embeddings.create = AsyncMock(return_value=mock_embedding_response)
|
||||
|
||||
with mock_client_context(mixin, mock_client):
|
||||
# Test chat completion
|
||||
await mixin.openai_chat_completion(
|
||||
OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
|
||||
)
|
||||
)
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
|
||||
# Test completion
|
||||
await mixin.openai_completion(
|
||||
OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello")
|
||||
)
|
||||
mock_client.completions.create.assert_called_once()
|
||||
|
||||
# Test embeddings
|
||||
await mixin.openai_embeddings(
|
||||
OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-ada-002", input="test text")
|
||||
)
|
||||
mock_client.embeddings.create.assert_called_once()
|
||||
|
||||
async def test_inference_with_disallowed_models(self, mixin, mock_client_context):
|
||||
"""Test that all inference methods fail with disallowed models"""
|
||||
mixin.config.allowed_models = ["gpt-4"]
|
||||
|
||||
mock_client = MagicMock()
|
||||
|
||||
with mock_client_context(mixin, mock_client):
|
||||
# Test chat completion with disallowed model
|
||||
with pytest.raises(ValueError, match="Model 'gpt-4-turbo' is not in the allowed models list"):
|
||||
await mixin.openai_chat_completion(
|
||||
OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="gpt-4-turbo", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
|
||||
)
|
||||
)
|
||||
|
||||
# Test completion with disallowed model
|
||||
with pytest.raises(ValueError, match="Model 'text-davinci-002' is not in the allowed models list"):
|
||||
await mixin.openai_completion(
|
||||
OpenAICompletionRequestWithExtraBody(model="text-davinci-002", prompt="Hello")
|
||||
)
|
||||
|
||||
# Test embeddings with disallowed model
|
||||
with pytest.raises(ValueError, match="Model 'text-embedding-3-large' is not in the allowed models list"):
|
||||
await mixin.openai_embeddings(
|
||||
OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-3-large", input="test text")
|
||||
)
|
||||
|
||||
mock_client.chat.completions.create.assert_not_called()
|
||||
mock_client.completions.create.assert_not_called()
|
||||
mock_client.embeddings.create.assert_not_called()
|
||||
|
||||
async def test_inference_with_no_restrictions(self, mixin, mock_client_context):
|
||||
"""Test that inference succeeds when allowed_models is None or empty list blocks all"""
|
||||
# Test with None (no restrictions)
|
||||
assert mixin.config.allowed_models is None
|
||||
mock_client = MagicMock()
|
||||
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
|
||||
|
||||
with mock_client_context(mixin, mock_client):
|
||||
await mixin.openai_chat_completion(
|
||||
OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="any-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
|
||||
)
|
||||
)
|
||||
mock_client.chat.completions.create.assert_called_once()
|
||||
|
||||
# Test with empty list (blocks all models)
|
||||
mixin.config.allowed_models = []
|
||||
with mock_client_context(mixin, mock_client):
|
||||
with pytest.raises(ValueError, match="Model 'gpt-4' is not in the allowed models list"):
|
||||
await mixin.openai_chat_completion(
|
||||
OpenAIChatCompletionRequestWithExtraBody(
|
||||
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
|
||||
)
|
||||
)
@@ -11,13 +11,13 @@ import numpy as np
import pytest
|
||||
|
||||
from llama_stack.core.storage.datatypes import KVStoreReference, SqliteKVStoreConfig
|
||||
from llama_stack.core.storage.kvstore import register_kvstore_backends
|
||||
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
|
||||
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
|
||||
from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig
|
||||
from llama_stack.providers.remote.vector_io.pgvector.pgvector import PGVectorIndex, PGVectorVectorIOAdapter
|
||||
from llama_stack.providers.utils.kvstore import register_kvstore_backends
|
||||
from llama_stack_api import Chunk, ChunkMetadata, QueryChunksResponse, VectorStore
|
||||
|
||||
EMBEDDING_DIMENSION = 768
@@ -279,7 +279,7 @@ async def pgvector_vec_adapter(unique_kvstore_config, mock_inference_api, embedd
) as mock_check_version:
|
||||
mock_check_version.return_value = "0.5.1"
|
||||
|
||||
-with patch("llama_stack.providers.utils.kvstore.kvstore_impl") as mock_kvstore_impl:
+with patch("llama_stack.core.storage.kvstore.kvstore_impl") as mock_kvstore_impl:
|
||||
mock_kvstore = AsyncMock()
|
||||
mock_kvstore_impl.return_value = mock_kvstore
|
||||
@@ -5,7 +5,7 @@
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||
-from llama_stack_api import Chunk, ChunkMetadata
+from llama_stack_api import Chunk, ChunkMetadata, VectorStoreFileObject
|
||||
|
||||
# This test is a unit test for the chunk_utils.py helpers. This should only contain
|
||||
# tests which are specific to this file. More general (API-level) tests should be placed in
@@ -78,3 +78,77 @@ def test_chunk_serialization():
serialized_chunk = chunk.model_dump()
|
||||
assert serialized_chunk["chunk_id"] == "test-chunk-id"
|
||||
assert "chunk_id" in serialized_chunk
|
||||
|
||||
|
||||
def test_vector_store_file_object_attributes_validation():
|
||||
"""Test VectorStoreFileObject validates and sanitizes attributes at input boundary."""
|
||||
# Test with metadata containing lists, nested dicts, and primitives
|
||||
from llama_stack_api.vector_io import VectorStoreChunkingStrategyAuto
|
||||
|
||||
file_obj = VectorStoreFileObject(
|
||||
id="file-123",
|
||||
attributes={
|
||||
"tags": ["transformers", "h100-compatible", "region:us"], # List -> string
|
||||
"model_name": "granite-3.3-8b", # String preserved
|
||||
"score": 0.95, # Float preserved
|
||||
"active": True, # Bool preserved
|
||||
"count": 42, # Int -> float
|
||||
"nested": {"key": "value"}, # Dict filtered out
|
||||
},
|
||||
chunking_strategy=VectorStoreChunkingStrategyAuto(),
|
||||
created_at=1234567890,
|
||||
status="completed",
|
||||
vector_store_id="vs-123",
|
||||
)
|
||||
|
||||
# Lists converted to comma-separated strings
|
||||
assert file_obj.attributes["tags"] == "transformers, h100-compatible, region:us"
|
||||
# Primitives preserved
|
||||
assert file_obj.attributes["model_name"] == "granite-3.3-8b"
|
||||
assert file_obj.attributes["score"] == 0.95
|
||||
assert file_obj.attributes["active"] is True
|
||||
assert file_obj.attributes["count"] == 42.0 # int -> float
|
||||
# Complex types filtered out
|
||||
assert "nested" not in file_obj.attributes
|
||||
|
||||
|
||||
def test_vector_store_file_object_attributes_constraints():
|
||||
"""Test VectorStoreFileObject enforces OpenAPI constraints on attributes."""
|
||||
from llama_stack_api.vector_io import VectorStoreChunkingStrategyAuto
|
||||
|
||||
# Test max 16 properties
|
||||
many_attrs = {f"key{i}": f"value{i}" for i in range(20)}
|
||||
file_obj = VectorStoreFileObject(
|
||||
id="file-123",
|
||||
attributes=many_attrs,
|
||||
chunking_strategy=VectorStoreChunkingStrategyAuto(),
|
||||
created_at=1234567890,
|
||||
status="completed",
|
||||
vector_store_id="vs-123",
|
||||
)
|
||||
assert len(file_obj.attributes) == 16 # Max 16 properties
|
||||
|
||||
# Test max 64 char keys are filtered
|
||||
long_key_attrs = {"a" * 65: "value", "valid_key": "value"}
|
||||
file_obj = VectorStoreFileObject(
|
||||
id="file-124",
|
||||
attributes=long_key_attrs,
|
||||
chunking_strategy=VectorStoreChunkingStrategyAuto(),
|
||||
created_at=1234567890,
|
||||
status="completed",
|
||||
vector_store_id="vs-123",
|
||||
)
|
||||
assert "a" * 65 not in file_obj.attributes
|
||||
assert "valid_key" in file_obj.attributes
|
||||
|
||||
# Test max 512 char string values are truncated
|
||||
long_value_attrs = {"key": "x" * 600}
|
||||
file_obj = VectorStoreFileObject(
|
||||
id="file-125",
|
||||
attributes=long_value_attrs,
|
||||
chunking_strategy=VectorStoreChunkingStrategyAuto(),
|
||||
created_at=1234567890,
|
||||
status="completed",
|
||||
vector_store_id="vs-123",
|
||||
)
|
||||
assert len(file_obj.attributes["key"]) == 512
|