Merge branch 'main' into fix/nvidia-safety-provider-endpoint-4189

This commit is contained in:
Roy Belio 2025-11-20 13:30:11 +02:00 committed by GitHub
commit f8f28344a5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
117 changed files with 16294 additions and 769 deletions

View file

@ -25,6 +25,13 @@ from llama_stack.providers.utils.responses.responses_store import (
ResponsesStore,
_OpenAIResponseObjectWithInputAndMessages,
)
from llama_stack_api import (
OpenAIChatCompletionContentPartImageParam,
OpenAIFile,
OpenAIFileObject,
OpenAISystemMessageParam,
Prompt,
)
from llama_stack_api.agents import Order
from llama_stack_api.inference import (
OpenAIAssistantMessageParam,
@ -38,6 +45,8 @@ from llama_stack_api.inference import (
)
from llama_stack_api.openai_responses import (
ListOpenAIResponseInputItem,
OpenAIResponseInputMessageContentFile,
OpenAIResponseInputMessageContentImage,
OpenAIResponseInputMessageContentText,
OpenAIResponseInputToolFunction,
OpenAIResponseInputToolMCP,
@ -47,6 +56,7 @@ from llama_stack_api.openai_responses import (
OpenAIResponseOutputMessageFunctionToolCall,
OpenAIResponseOutputMessageMCPCall,
OpenAIResponseOutputMessageWebSearchToolCall,
OpenAIResponsePrompt,
OpenAIResponseText,
OpenAIResponseTextFormat,
WebSearchToolTypes,
@ -98,6 +108,19 @@ def mock_safety_api():
return safety_api
@pytest.fixture
def mock_prompts_api():
prompts_api = AsyncMock()
return prompts_api
@pytest.fixture
def mock_files_api():
"""Mock files API for testing."""
files_api = AsyncMock()
return files_api
@pytest.fixture
def openai_responses_impl(
mock_inference_api,
@ -107,6 +130,8 @@ def openai_responses_impl(
mock_vector_io_api,
mock_safety_api,
mock_conversations_api,
mock_prompts_api,
mock_files_api,
):
return OpenAIResponsesImpl(
inference_api=mock_inference_api,
@ -116,6 +141,8 @@ def openai_responses_impl(
vector_io_api=mock_vector_io_api,
safety_api=mock_safety_api,
conversations_api=mock_conversations_api,
prompts_api=mock_prompts_api,
files_api=mock_files_api,
)
@ -499,7 +526,7 @@ async def test_create_openai_response_with_tool_call_function_arguments_none(ope
mock_inference_api.openai_chat_completion.return_value = fake_stream_toolcall()
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api, mock_files_api):
"""Test creating an OpenAI response with multiple messages."""
# Setup
input_messages = [
@ -710,7 +737,7 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
async def test_create_openai_response_with_instructions_and_multiple_messages(
openai_responses_impl, mock_inference_api
openai_responses_impl, mock_inference_api, mock_files_api
):
# Setup
input_messages = [
@ -1242,3 +1269,489 @@ async def test_create_openai_response_with_output_types_as_input(
assert stored_with_outputs.input == input_with_output_types
assert len(stored_with_outputs.input) == 3
async def test_create_openai_response_with_prompt(openai_responses_impl, mock_inference_api, mock_prompts_api):
"""Test creating an OpenAI response with a prompt."""
input_text = "What is the capital of Ireland?"
model = "meta-llama/Llama-3.1-8B-Instruct"
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
prompt = Prompt(
prompt="You are a helpful {{ area_name }} assistant at {{ company_name }}. Always provide accurate information.",
prompt_id=prompt_id,
version=1,
variables=["area_name", "company_name"],
is_default=True,
)
openai_response_prompt = OpenAIResponsePrompt(
id=prompt_id,
version="1",
variables={
"area_name": OpenAIResponseInputMessageContentText(text="geography"),
"company_name": OpenAIResponseInputMessageContentText(text="Dummy Company"),
},
)
mock_prompts_api.get_prompt.return_value = prompt
mock_inference_api.openai_chat_completion.return_value = fake_stream()
result = await openai_responses_impl.create_openai_response(
input=input_text,
model=model,
prompt=openai_response_prompt,
)
mock_prompts_api.get_prompt.assert_called_with(prompt_id, 1)
mock_inference_api.openai_chat_completion.assert_called()
call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.args[0].messages
assert len(sent_messages) == 2
system_messages = [msg for msg in sent_messages if msg.role == "system"]
assert len(system_messages) == 1
assert (
system_messages[0].content
== "You are a helpful geography assistant at Dummy Company. Always provide accurate information."
)
user_messages = [msg for msg in sent_messages if msg.role == "user"]
assert len(user_messages) == 1
assert user_messages[0].content == input_text
assert result.model == model
assert result.status == "completed"
assert isinstance(result.prompt, OpenAIResponsePrompt)
assert result.prompt.id == prompt_id
assert result.prompt.variables == openai_response_prompt.variables
assert result.prompt.version == "1"
async def test_prepend_prompt_successful_without_variables(openai_responses_impl, mock_prompts_api, mock_inference_api):
"""Test prepend_prompt function without variables."""
input_text = "What is the capital of Ireland?"
model = "meta-llama/Llama-3.1-8B-Instruct"
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
prompt = Prompt(
prompt="You are a helpful assistant. Always provide accurate information.",
prompt_id=prompt_id,
version=1,
variables=[],
is_default=True,
)
openai_response_prompt = OpenAIResponsePrompt(id=prompt_id, version="1")
mock_prompts_api.get_prompt.return_value = prompt
mock_inference_api.openai_chat_completion.return_value = fake_stream()
await openai_responses_impl.create_openai_response(
input=input_text,
model=model,
prompt=openai_response_prompt,
)
mock_prompts_api.get_prompt.assert_called_with(prompt_id, 1)
mock_inference_api.openai_chat_completion.assert_called()
call_args = mock_inference_api.openai_chat_completion.call_args
sent_messages = call_args.args[0].messages
assert len(sent_messages) == 2
system_messages = [msg for msg in sent_messages if msg.role == "system"]
assert system_messages[0].content == "You are a helpful assistant. Always provide accurate information."
async def test_prepend_prompt_invalid_variable(openai_responses_impl, mock_prompts_api):
"""Test error handling in prepend_prompt function when prompt parameters contain invalid variables."""
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
prompt = Prompt(
prompt="You are a {{ role }} assistant.",
prompt_id=prompt_id,
version=1,
variables=["role"], # Only "role" is valid
is_default=True,
)
openai_response_prompt = OpenAIResponsePrompt(
id=prompt_id,
version="1",
variables={
"role": OpenAIResponseInputMessageContentText(text="helpful"),
"company": OpenAIResponseInputMessageContentText(
text="Dummy Company"
), # company is not in prompt.variables
},
)
mock_prompts_api.get_prompt.return_value = prompt
# Initial messages
messages = [OpenAIUserMessageParam(content="Test prompt")]
# Execute - should raise ValueError for invalid variable
with pytest.raises(ValueError, match="Variable company not found in prompt"):
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
# Verify
mock_prompts_api.get_prompt.assert_called_once_with(prompt_id, 1)
async def test_prepend_prompt_not_found(openai_responses_impl, mock_prompts_api):
"""Test prepend_prompt function when prompt is not found."""
prompt_id = "pmpt_nonexistent"
openai_response_prompt = OpenAIResponsePrompt(id=prompt_id, version="1")
mock_prompts_api.get_prompt.return_value = None # Prompt not found
# Initial messages
messages = [OpenAIUserMessageParam(content="Test prompt")]
initial_length = len(messages)
# Execute
result = await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
# Verify
mock_prompts_api.get_prompt.assert_called_once_with(prompt_id, 1)
# Should return None when prompt not found
assert result is None
# Messages should not be modified
assert len(messages) == initial_length
assert messages[0].content == "Test prompt"
async def test_prepend_prompt_variable_substitution(openai_responses_impl, mock_prompts_api):
"""Test complex variable substitution with multiple occurrences and special characters in prepend_prompt function."""
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
# Support all whitespace variations: {{name}}, {{ name }}, {{ name}}, {{name }}, etc.
prompt = Prompt(
prompt="Hello {{name}}! You are working at {{ company}}. Your role is {{role}} at {{company}}. Remember, {{ name }}, to be {{ tone }}.",
prompt_id=prompt_id,
version=1,
variables=["name", "company", "role", "tone"],
is_default=True,
)
openai_response_prompt = OpenAIResponsePrompt(
id=prompt_id,
version="1",
variables={
"name": OpenAIResponseInputMessageContentText(text="Alice"),
"company": OpenAIResponseInputMessageContentText(text="Dummy Company"),
"role": OpenAIResponseInputMessageContentText(text="AI Assistant"),
"tone": OpenAIResponseInputMessageContentText(text="professional"),
},
)
mock_prompts_api.get_prompt.return_value = prompt
# Initial messages
messages = [OpenAIUserMessageParam(content="Test")]
# Execute
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
# Verify
assert len(messages) == 2
assert isinstance(messages[0], OpenAISystemMessageParam)
expected_content = "Hello Alice! You are working at Dummy Company. Your role is AI Assistant at Dummy Company. Remember, Alice, to be professional."
assert messages[0].content == expected_content
async def test_prepend_prompt_with_image_variable(openai_responses_impl, mock_prompts_api, mock_files_api):
"""Test prepend_prompt with image variable - should create placeholder in system message and append image as separate user message."""
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
prompt = Prompt(
prompt="Analyze this {{product_image}} and describe what you see.",
prompt_id=prompt_id,
version=1,
variables=["product_image"],
is_default=True,
)
# Mock file content and file metadata
mock_file_content = b"fake_image_data"
mock_files_api.openai_retrieve_file_content.return_value = type("obj", (object,), {"body": mock_file_content})()
mock_files_api.openai_retrieve_file.return_value = OpenAIFileObject(
object="file",
id="file-abc123",
bytes=len(mock_file_content),
created_at=1234567890,
expires_at=1234567890,
filename="product.jpg",
purpose="assistants",
)
openai_response_prompt = OpenAIResponsePrompt(
id=prompt_id,
version="1",
variables={
"product_image": OpenAIResponseInputMessageContentImage(
file_id="file-abc123",
detail="high",
)
},
)
mock_prompts_api.get_prompt.return_value = prompt
# Initial messages
messages = [OpenAIUserMessageParam(content="What do you think?")]
# Execute
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
assert len(messages) == 3
# Check system message has placeholder
assert isinstance(messages[0], OpenAISystemMessageParam)
assert messages[0].content == "Analyze this [Image: product_image] and describe what you see."
# Check original user message is still there
assert isinstance(messages[1], OpenAIUserMessageParam)
assert messages[1].content == "What do you think?"
# Check new user message with image is appended
assert isinstance(messages[2], OpenAIUserMessageParam)
assert isinstance(messages[2].content, list)
assert len(messages[2].content) == 1
# Should be image with data URL
assert isinstance(messages[2].content[0], OpenAIChatCompletionContentPartImageParam)
assert messages[2].content[0].image_url.url.startswith("data:image/")
assert messages[2].content[0].image_url.detail == "high"
async def test_prepend_prompt_with_file_variable(openai_responses_impl, mock_prompts_api, mock_files_api):
"""Test prepend_prompt with file variable - should create placeholder in system message and append file as separate user message."""
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
prompt = Prompt(
prompt="Review the document {{contract_file}} and summarize key points.",
prompt_id=prompt_id,
version=1,
variables=["contract_file"],
is_default=True,
)
# Mock file retrieval
mock_file_content = b"fake_pdf_content"
mock_files_api.openai_retrieve_file_content.return_value = type("obj", (object,), {"body": mock_file_content})()
mock_files_api.openai_retrieve_file.return_value = OpenAIFileObject(
object="file",
id="file-contract-789",
bytes=len(mock_file_content),
created_at=1234567890,
expires_at=1234567890,
filename="contract.pdf",
purpose="assistants",
)
openai_response_prompt = OpenAIResponsePrompt(
id=prompt_id,
version="1",
variables={
"contract_file": OpenAIResponseInputMessageContentFile(
file_id="file-contract-789",
filename="contract.pdf",
)
},
)
mock_prompts_api.get_prompt.return_value = prompt
# Initial messages
messages = [OpenAIUserMessageParam(content="Please review this.")]
# Execute
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
assert len(messages) == 3
# Check system message has placeholder
assert isinstance(messages[0], OpenAISystemMessageParam)
assert messages[0].content == "Review the document [File: contract_file] and summarize key points."
# Check original user message is still there
assert isinstance(messages[1], OpenAIUserMessageParam)
assert messages[1].content == "Please review this."
# Check new user message with file is appended
assert isinstance(messages[2], OpenAIUserMessageParam)
assert isinstance(messages[2].content, list)
assert len(messages[2].content) == 1
# First part should be file with data URL
assert isinstance(messages[2].content[0], OpenAIFile)
assert messages[2].content[0].file.file_data.startswith("data:application/pdf;base64,")
assert messages[2].content[0].file.filename == "contract.pdf"
assert messages[2].content[0].file.file_id is None
async def test_prepend_prompt_with_mixed_variables(openai_responses_impl, mock_prompts_api, mock_files_api):
"""Test prepend_prompt with text, image, and file variables mixed together."""
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
prompt = Prompt(
prompt="Hello {{name}}! Analyze {{photo}} and review {{document}}. Provide insights for {{company}}.",
prompt_id=prompt_id,
version=1,
variables=["name", "photo", "document", "company"],
is_default=True,
)
# Mock file retrieval for image and file
mock_image_content = b"fake_image_data"
mock_file_content = b"fake_doc_content"
async def mock_retrieve_file_content(file_id):
if file_id == "file-photo-123":
return type("obj", (object,), {"body": mock_image_content})()
elif file_id == "file-doc-456":
return type("obj", (object,), {"body": mock_file_content})()
mock_files_api.openai_retrieve_file_content.side_effect = mock_retrieve_file_content
def mock_retrieve_file(file_id):
if file_id == "file-photo-123":
return OpenAIFileObject(
object="file",
id="file-photo-123",
bytes=len(mock_image_content),
created_at=1234567890,
expires_at=1234567890,
filename="photo.jpg",
purpose="assistants",
)
elif file_id == "file-doc-456":
return OpenAIFileObject(
object="file",
id="file-doc-456",
bytes=len(mock_file_content),
created_at=1234567890,
expires_at=1234567890,
filename="doc.pdf",
purpose="assistants",
)
mock_files_api.openai_retrieve_file.side_effect = mock_retrieve_file
openai_response_prompt = OpenAIResponsePrompt(
id=prompt_id,
version="1",
variables={
"name": OpenAIResponseInputMessageContentText(text="Alice"),
"photo": OpenAIResponseInputMessageContentImage(file_id="file-photo-123", detail="auto"),
"document": OpenAIResponseInputMessageContentFile(file_id="file-doc-456", filename="doc.pdf"),
"company": OpenAIResponseInputMessageContentText(text="Acme Corp"),
},
)
mock_prompts_api.get_prompt.return_value = prompt
# Initial messages
messages = [OpenAIUserMessageParam(content="Here's my question.")]
# Execute
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
assert len(messages) == 3
# Check system message has text and placeholders
assert isinstance(messages[0], OpenAISystemMessageParam)
expected_system = "Hello Alice! Analyze [Image: photo] and review [File: document]. Provide insights for Acme Corp."
assert messages[0].content == expected_system
# Check original user message is still there
assert isinstance(messages[1], OpenAIUserMessageParam)
assert messages[1].content == "Here's my question."
# Check new user message with media is appended (2 media items)
assert isinstance(messages[2], OpenAIUserMessageParam)
assert isinstance(messages[2].content, list)
assert len(messages[2].content) == 2
# First part should be image with data URL
assert isinstance(messages[2].content[0], OpenAIChatCompletionContentPartImageParam)
assert messages[2].content[0].image_url.url.startswith("data:image/")
# Second part should be file with data URL
assert isinstance(messages[2].content[1], OpenAIFile)
assert messages[2].content[1].file.file_data.startswith("data:application/pdf;base64,")
assert messages[2].content[1].file.filename == "doc.pdf"
assert messages[2].content[1].file.file_id is None
async def test_prepend_prompt_with_image_using_image_url(openai_responses_impl, mock_prompts_api):
"""Test prepend_prompt with image variable using image_url instead of file_id."""
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
prompt = Prompt(
prompt="Describe {{screenshot}}.",
prompt_id=prompt_id,
version=1,
variables=["screenshot"],
is_default=True,
)
openai_response_prompt = OpenAIResponsePrompt(
id=prompt_id,
version="1",
variables={
"screenshot": OpenAIResponseInputMessageContentImage(
image_url="https://example.com/screenshot.png",
detail="low",
)
},
)
mock_prompts_api.get_prompt.return_value = prompt
# Initial messages
messages = [OpenAIUserMessageParam(content="What is this?")]
# Execute
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)
assert len(messages) == 3
# Check system message has placeholder
assert isinstance(messages[0], OpenAISystemMessageParam)
assert messages[0].content == "Describe [Image: screenshot]."
# Check original user message is still there
assert isinstance(messages[1], OpenAIUserMessageParam)
assert messages[1].content == "What is this?"
# Check new user message with image is appended
assert isinstance(messages[2], OpenAIUserMessageParam)
assert isinstance(messages[2].content, list)
# Image should use the provided URL
assert isinstance(messages[2].content[0], OpenAIChatCompletionContentPartImageParam)
assert messages[2].content[0].image_url.url == "https://example.com/screenshot.png"
assert messages[2].content[0].image_url.detail == "low"
async def test_prepend_prompt_image_variable_missing_required_fields(openai_responses_impl, mock_prompts_api):
"""Test prepend_prompt with image variable that has neither file_id nor image_url - should raise error."""
prompt_id = "pmpt_1234567890abcdef1234567890abcdef1234567890abcdef"
prompt = Prompt(
prompt="Analyze {{bad_image}}.",
prompt_id=prompt_id,
version=1,
variables=["bad_image"],
is_default=True,
)
# Create image content with neither file_id nor image_url
openai_response_prompt = OpenAIResponsePrompt(
id=prompt_id,
version="1",
variables={"bad_image": OpenAIResponseInputMessageContentImage()}, # No file_id or image_url
)
mock_prompts_api.get_prompt.return_value = prompt
messages = [OpenAIUserMessageParam(content="Test")]
# Execute - should raise ValueError
with pytest.raises(ValueError, match="Image content must have either 'image_url' or 'file_id'"):
await openai_responses_impl._prepend_prompt(messages, openai_response_prompt)

View file

@ -39,6 +39,8 @@ def responses_impl_with_conversations(
mock_vector_io_api,
mock_conversations_api,
mock_safety_api,
mock_prompts_api,
mock_files_api,
):
"""Create OpenAIResponsesImpl instance with conversations API."""
return OpenAIResponsesImpl(
@ -49,6 +51,8 @@ def responses_impl_with_conversations(
vector_io_api=mock_vector_io_api,
conversations_api=mock_conversations_api,
safety_api=mock_safety_api,
prompts_api=mock_prompts_api,
files_api=mock_files_api,
)

View file

@ -5,6 +5,8 @@
# the root directory of this source tree.
from unittest.mock import AsyncMock
import pytest
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
@ -46,6 +48,12 @@ from llama_stack_api.openai_responses import (
)
@pytest.fixture
def mock_files_api():
"""Mock files API for testing."""
return AsyncMock()
class TestConvertChatChoiceToResponseMessage:
async def test_convert_string_content(self):
choice = OpenAIChoice(
@ -78,17 +86,17 @@ class TestConvertChatChoiceToResponseMessage:
class TestConvertResponseContentToChatContent:
async def test_convert_string_content(self):
result = await convert_response_content_to_chat_content("Simple string")
async def test_convert_string_content(self, mock_files_api):
result = await convert_response_content_to_chat_content("Simple string", mock_files_api)
assert result == "Simple string"
async def test_convert_text_content_parts(self):
async def test_convert_text_content_parts(self, mock_files_api):
content = [
OpenAIResponseInputMessageContentText(text="First part"),
OpenAIResponseOutputMessageContentOutputText(text="Second part"),
]
result = await convert_response_content_to_chat_content(content)
result = await convert_response_content_to_chat_content(content, mock_files_api)
assert len(result) == 2
assert isinstance(result[0], OpenAIChatCompletionContentPartTextParam)
@ -96,10 +104,10 @@ class TestConvertResponseContentToChatContent:
assert isinstance(result[1], OpenAIChatCompletionContentPartTextParam)
assert result[1].text == "Second part"
async def test_convert_image_content(self):
async def test_convert_image_content(self, mock_files_api):
content = [OpenAIResponseInputMessageContentImage(image_url="https://example.com/image.jpg", detail="high")]
result = await convert_response_content_to_chat_content(content)
result = await convert_response_content_to_chat_content(content, mock_files_api)
assert len(result) == 1
assert isinstance(result[0], OpenAIChatCompletionContentPartImageParam)

View file

@ -30,6 +30,8 @@ def mock_apis():
"vector_io_api": AsyncMock(),
"conversations_api": AsyncMock(),
"safety_api": AsyncMock(),
"prompts_api": AsyncMock(),
"files_api": AsyncMock(),
}

View file

@ -0,0 +1,214 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""Tests for making Safety API optional in meta-reference agents provider.
This test suite validates the changes introduced to fix issue #4165, which
allows running the meta-reference agents provider without the Safety API.
Safety API is now an optional dependency, and errors are raised at request time
when guardrails are explicitly requested without Safety API configured.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack.core.datatypes import Api
from llama_stack.core.storage.datatypes import KVStoreReference, ResponsesStoreReference
from llama_stack.providers.inline.agents.meta_reference import get_provider_impl
from llama_stack.providers.inline.agents.meta_reference.config import (
AgentPersistenceConfig,
MetaReferenceAgentsImplConfig,
)
from llama_stack.providers.inline.agents.meta_reference.responses.utils import (
run_guardrails,
)
@pytest.fixture
def mock_persistence_config():
"""Create a mock persistence configuration."""
return AgentPersistenceConfig(
agent_state=KVStoreReference(
backend="kv_default",
namespace="agents",
),
responses=ResponsesStoreReference(
backend="sql_default",
table_name="responses",
),
)
@pytest.fixture
def mock_deps():
"""Create mock dependencies for the agents provider."""
# Create mock APIs
inference_api = AsyncMock()
vector_io_api = AsyncMock()
tool_runtime_api = AsyncMock()
tool_groups_api = AsyncMock()
conversations_api = AsyncMock()
prompts_api = AsyncMock()
files_api = AsyncMock()
return {
Api.inference: inference_api,
Api.vector_io: vector_io_api,
Api.tool_runtime: tool_runtime_api,
Api.tool_groups: tool_groups_api,
Api.conversations: conversations_api,
Api.prompts: prompts_api,
Api.files: files_api,
}
class TestProviderInitialization:
"""Test provider initialization with different safety API configurations."""
async def test_initialization_with_safety_api_present(self, mock_persistence_config, mock_deps):
"""Test successful initialization when Safety API is configured."""
config = MetaReferenceAgentsImplConfig(persistence=mock_persistence_config)
# Add safety API to deps
safety_api = AsyncMock()
mock_deps[Api.safety] = safety_api
# Mock the initialize method to avoid actual initialization
with patch(
"llama_stack.providers.inline.agents.meta_reference.agents.MetaReferenceAgentsImpl.initialize",
new_callable=AsyncMock,
):
# Should not raise any exception
provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
assert provider is not None
async def test_initialization_without_safety_api(self, mock_persistence_config, mock_deps):
"""Test successful initialization when Safety API is not configured."""
config = MetaReferenceAgentsImplConfig(persistence=mock_persistence_config)
# Safety API is NOT in mock_deps - provider should still start
# Mock the initialize method to avoid actual initialization
with patch(
"llama_stack.providers.inline.agents.meta_reference.agents.MetaReferenceAgentsImpl.initialize",
new_callable=AsyncMock,
):
# Should not raise any exception
provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
assert provider is not None
assert provider.safety_api is None
class TestGuardrailsFunctionality:
"""Test run_guardrails function with optional safety API."""
async def test_run_guardrails_with_none_safety_api(self):
"""Test that run_guardrails returns None when safety_api is None."""
result = await run_guardrails(safety_api=None, messages="test message", guardrail_ids=["llama-guard"])
assert result is None
async def test_run_guardrails_with_empty_messages(self):
"""Test that run_guardrails returns None for empty messages."""
# Test with None safety API
result = await run_guardrails(safety_api=None, messages="", guardrail_ids=["llama-guard"])
assert result is None
# Test with mock safety API
mock_safety_api = AsyncMock()
result = await run_guardrails(safety_api=mock_safety_api, messages="", guardrail_ids=["llama-guard"])
assert result is None
async def test_run_guardrails_with_none_safety_api_ignores_guardrails(self):
"""Test that guardrails are skipped when safety_api is None, even if guardrail_ids are provided."""
# Should not raise exception, just return None
result = await run_guardrails(
safety_api=None,
messages="potentially harmful content",
guardrail_ids=["llama-guard", "content-filter"],
)
assert result is None
async def test_create_response_rejects_guardrails_without_safety_api(self, mock_persistence_config, mock_deps):
"""Test that create_openai_response raises error when guardrails requested but Safety API unavailable."""
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
from llama_stack_api import ResponseGuardrailSpec
# Create OpenAIResponsesImpl with no safety API
with patch("llama_stack.providers.inline.agents.meta_reference.responses.openai_responses.ResponsesStore"):
impl = OpenAIResponsesImpl(
inference_api=mock_deps[Api.inference],
tool_groups_api=mock_deps[Api.tool_groups],
tool_runtime_api=mock_deps[Api.tool_runtime],
responses_store=MagicMock(),
vector_io_api=mock_deps[Api.vector_io],
safety_api=None, # No Safety API
conversations_api=mock_deps[Api.conversations],
prompts_api=mock_deps[Api.prompts],
files_api=mock_deps[Api.files],
)
# Test with string guardrail
with pytest.raises(ValueError) as exc_info:
await impl.create_openai_response(
input="test input",
model="test-model",
guardrails=["llama-guard"],
)
assert "Cannot process guardrails: Safety API is not configured" in str(exc_info.value)
# Test with ResponseGuardrailSpec
with pytest.raises(ValueError) as exc_info:
await impl.create_openai_response(
input="test input",
model="test-model",
guardrails=[ResponseGuardrailSpec(type="llama-guard")],
)
assert "Cannot process guardrails: Safety API is not configured" in str(exc_info.value)
async def test_create_response_succeeds_without_guardrails_and_no_safety_api(
self, mock_persistence_config, mock_deps
):
"""Test that create_openai_response works when no guardrails requested and Safety API unavailable."""
from llama_stack.providers.inline.agents.meta_reference.responses.openai_responses import (
OpenAIResponsesImpl,
)
# Create OpenAIResponsesImpl with no safety API
with (
patch("llama_stack.providers.inline.agents.meta_reference.responses.openai_responses.ResponsesStore"),
patch.object(OpenAIResponsesImpl, "_create_streaming_response", new_callable=AsyncMock) as mock_stream,
):
# Mock the streaming response to return a simple async generator
async def mock_generator():
yield MagicMock()
mock_stream.return_value = mock_generator()
impl = OpenAIResponsesImpl(
inference_api=mock_deps[Api.inference],
tool_groups_api=mock_deps[Api.tool_groups],
tool_runtime_api=mock_deps[Api.tool_runtime],
responses_store=MagicMock(),
vector_io_api=mock_deps[Api.vector_io],
safety_api=None, # No Safety API
conversations_api=mock_deps[Api.conversations],
prompts_api=mock_deps[Api.prompts],
files_api=mock_deps[Api.files],
)
# Should not raise when no guardrails requested
# Note: This will still fail later in execution due to mocking, but should pass the validation
try:
await impl.create_openai_response(
input="test input",
model="test-model",
guardrails=None, # No guardrails
)
except Exception as e:
# Ensure the error is NOT about missing Safety API
assert "Cannot process guardrails: Safety API is not configured" not in str(e)

View file

@ -120,7 +120,7 @@ from llama_stack.providers.remote.inference.watsonx.watsonx import WatsonXInfere
VLLMInferenceAdapter,
"llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
{
"url": "http://fake",
"base_url": "http://fake",
},
),
],
@ -153,7 +153,7 @@ def test_litellm_provider_data_used(config_cls, adapter_cls, provider_data_valid
"""Validate data for LiteLLM-based providers. Similar to test_openai_provider_data_used, but without the
assumption that there is an OpenAI-compatible client object."""
inference_adapter = adapter_cls(config=config_cls())
inference_adapter = adapter_cls(config=config_cls(base_url="http://fake"))
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = provider_data_validator

View file

@ -40,7 +40,7 @@ from llama_stack_api import (
@pytest.fixture(scope="function")
async def vllm_inference_adapter():
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
inference_adapter = VLLMInferenceAdapter(config=config)
inference_adapter.model_store = AsyncMock()
await inference_adapter.initialize()
@ -204,7 +204,7 @@ async def test_vllm_completion_extra_body():
via extra_body to the underlying OpenAI client through the InferenceRouter.
"""
# Set up the vLLM adapter
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
vllm_adapter = VLLMInferenceAdapter(config=config)
vllm_adapter.__provider_id__ = "vllm"
await vllm_adapter.initialize()
@ -277,7 +277,7 @@ async def test_vllm_chat_completion_extra_body():
via extra_body to the underlying OpenAI client through the InferenceRouter for chat completion.
"""
# Set up the vLLM adapter
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
config = VLLMInferenceAdapterConfig(base_url="http://mocked.localhost:12345")
vllm_adapter = VLLMInferenceAdapter(config=config)
vllm_adapter.__provider_id__ = "vllm"
await vllm_adapter.initialize()

View file

@ -146,7 +146,7 @@ async def test_hosted_model_not_in_endpoint_mapping():
async def test_self_hosted_ignores_endpoint():
adapter = create_adapter(
config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
config=NVIDIAConfig(base_url="http://localhost:8000", api_key=None),
rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted.
)
mock_session = MockSession(MockResponse())

View file

@ -4,8 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import get_args, get_origin
import pytest
from pydantic import BaseModel
from pydantic import BaseModel, HttpUrl
from llama_stack.core.distribution import get_provider_registry, providable_apis
from llama_stack.core.utils.dynamic import instantiate_class_type
@ -41,3 +43,55 @@ class TestProviderConfigurations:
sample_config = config_type.sample_run_config(__distro_dir__="foobarbaz")
assert isinstance(sample_config, dict), f"{config_class_name}.sample_run_config() did not return a dict"
def test_remote_inference_url_standardization(self):
"""Verify all remote inference providers use standardized base_url configuration."""
provider_registry = get_provider_registry()
inference_providers = provider_registry.get("inference", {})
# Filter for remote providers only
remote_providers = {k: v for k, v in inference_providers.items() if k.startswith("remote::")}
failures = []
for provider_type, provider_spec in remote_providers.items():
try:
config_class_name = provider_spec.config_class
config_type = instantiate_class_type(config_class_name)
# Check that config has base_url field (not url)
if hasattr(config_type, "model_fields"):
fields = config_type.model_fields
# Should NOT have 'url' field (old pattern)
if "url" in fields:
failures.append(
f"{provider_type}: Uses deprecated 'url' field instead of 'base_url'. "
f"Please rename to 'base_url' for consistency."
)
# Should have 'base_url' field with HttpUrl | None type
if "base_url" in fields:
field_info = fields["base_url"]
annotation = field_info.annotation
# Check if it's HttpUrl or HttpUrl | None
# get_origin() returns Union for (X | Y), None for plain types
# get_args() returns the types inside Union, e.g. (HttpUrl, NoneType)
is_valid = False
if get_origin(annotation) is not None: # It's a Union/Optional
if HttpUrl in get_args(annotation):
is_valid = True
elif annotation == HttpUrl: # Plain HttpUrl without | None
is_valid = True
if not is_valid:
failures.append(
f"{provider_type}: base_url field has incorrect type annotation. "
f"Expected 'HttpUrl | None', got '{annotation}'"
)
except Exception as e:
failures.append(f"{provider_type}: Error checking URL standardization: {str(e)}")
if failures:
pytest.fail("URL standardization violations found:\n" + "\n".join(f" - {f}" for f in failures))

View file

@ -15,7 +15,14 @@ from pydantic import BaseModel, Field
from llama_stack.core.request_headers import request_provider_data_context
from llama_stack.providers.utils.inference.model_registry import RemoteInferenceProviderConfig
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack_api import Model, ModelType, OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
from llama_stack_api import (
Model,
ModelType,
OpenAIChatCompletionRequestWithExtraBody,
OpenAICompletionRequestWithExtraBody,
OpenAIEmbeddingsRequestWithExtraBody,
OpenAIUserMessageParam,
)
class OpenAIMixinImpl(OpenAIMixin):
@ -834,3 +841,96 @@ class TestOpenAIMixinProviderDataApiKey:
error_message = str(exc_info.value)
assert "test_api_key" in error_message
assert "x-llamastack-provider-data" in error_message
class TestOpenAIMixinAllowedModelsInference:
"""Test cases for allowed_models enforcement during inference requests"""
async def test_inference_with_allowed_models(self, mixin, mock_client_context):
"""Test that all inference methods succeed with allowed models"""
mixin.config.allowed_models = ["gpt-4", "text-davinci-003", "text-embedding-ada-002"]
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
mock_client.completions.create = AsyncMock(return_value=MagicMock())
mock_embedding_response = MagicMock()
mock_embedding_response.data = [MagicMock(embedding=[0.1, 0.2, 0.3])]
mock_embedding_response.usage = MagicMock(prompt_tokens=5, total_tokens=5)
mock_client.embeddings.create = AsyncMock(return_value=mock_embedding_response)
with mock_client_context(mixin, mock_client):
# Test chat completion
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
)
)
mock_client.chat.completions.create.assert_called_once()
# Test completion
await mixin.openai_completion(
OpenAICompletionRequestWithExtraBody(model="text-davinci-003", prompt="Hello")
)
mock_client.completions.create.assert_called_once()
# Test embeddings
await mixin.openai_embeddings(
OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-ada-002", input="test text")
)
mock_client.embeddings.create.assert_called_once()
async def test_inference_with_disallowed_models(self, mixin, mock_client_context):
"""Test that all inference methods fail with disallowed models"""
mixin.config.allowed_models = ["gpt-4"]
mock_client = MagicMock()
with mock_client_context(mixin, mock_client):
# Test chat completion with disallowed model
with pytest.raises(ValueError, match="Model 'gpt-4-turbo' is not in the allowed models list"):
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="gpt-4-turbo", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
)
)
# Test completion with disallowed model
with pytest.raises(ValueError, match="Model 'text-davinci-002' is not in the allowed models list"):
await mixin.openai_completion(
OpenAICompletionRequestWithExtraBody(model="text-davinci-002", prompt="Hello")
)
# Test embeddings with disallowed model
with pytest.raises(ValueError, match="Model 'text-embedding-3-large' is not in the allowed models list"):
await mixin.openai_embeddings(
OpenAIEmbeddingsRequestWithExtraBody(model="text-embedding-3-large", input="test text")
)
mock_client.chat.completions.create.assert_not_called()
mock_client.completions.create.assert_not_called()
mock_client.embeddings.create.assert_not_called()
async def test_inference_with_no_restrictions(self, mixin, mock_client_context):
"""Test that inference succeeds when allowed_models is None or empty list blocks all"""
# Test with None (no restrictions)
assert mixin.config.allowed_models is None
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock(return_value=MagicMock())
with mock_client_context(mixin, mock_client):
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="any-model", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
)
)
mock_client.chat.completions.create.assert_called_once()
# Test with empty list (blocks all models)
mixin.config.allowed_models = []
with mock_client_context(mixin, mock_client):
with pytest.raises(ValueError, match="Model 'gpt-4' is not in the allowed models list"):
await mixin.openai_chat_completion(
OpenAIChatCompletionRequestWithExtraBody(
model="gpt-4", messages=[OpenAIUserMessageParam(role="user", content="Hello")]
)
)

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
from llama_stack_api import Chunk, ChunkMetadata
from llama_stack_api import Chunk, ChunkMetadata, VectorStoreFileObject
# This test is a unit test for the chunk_utils.py helpers. This should only contain
# tests which are specific to this file. More general (API-level) tests should be placed in
@ -78,3 +78,77 @@ def test_chunk_serialization():
serialized_chunk = chunk.model_dump()
assert serialized_chunk["chunk_id"] == "test-chunk-id"
assert "chunk_id" in serialized_chunk
def test_vector_store_file_object_attributes_validation():
"""Test VectorStoreFileObject validates and sanitizes attributes at input boundary."""
# Test with metadata containing lists, nested dicts, and primitives
from llama_stack_api.vector_io import VectorStoreChunkingStrategyAuto
file_obj = VectorStoreFileObject(
id="file-123",
attributes={
"tags": ["transformers", "h100-compatible", "region:us"], # List -> string
"model_name": "granite-3.3-8b", # String preserved
"score": 0.95, # Float preserved
"active": True, # Bool preserved
"count": 42, # Int -> float
"nested": {"key": "value"}, # Dict filtered out
},
chunking_strategy=VectorStoreChunkingStrategyAuto(),
created_at=1234567890,
status="completed",
vector_store_id="vs-123",
)
# Lists converted to comma-separated strings
assert file_obj.attributes["tags"] == "transformers, h100-compatible, region:us"
# Primitives preserved
assert file_obj.attributes["model_name"] == "granite-3.3-8b"
assert file_obj.attributes["score"] == 0.95
assert file_obj.attributes["active"] is True
assert file_obj.attributes["count"] == 42.0 # int -> float
# Complex types filtered out
assert "nested" not in file_obj.attributes
def test_vector_store_file_object_attributes_constraints():
"""Test VectorStoreFileObject enforces OpenAPI constraints on attributes."""
from llama_stack_api.vector_io import VectorStoreChunkingStrategyAuto
# Test max 16 properties
many_attrs = {f"key{i}": f"value{i}" for i in range(20)}
file_obj = VectorStoreFileObject(
id="file-123",
attributes=many_attrs,
chunking_strategy=VectorStoreChunkingStrategyAuto(),
created_at=1234567890,
status="completed",
vector_store_id="vs-123",
)
assert len(file_obj.attributes) == 16 # Max 16 properties
# Test max 64 char keys are filtered
long_key_attrs = {"a" * 65: "value", "valid_key": "value"}
file_obj = VectorStoreFileObject(
id="file-124",
attributes=long_key_attrs,
chunking_strategy=VectorStoreChunkingStrategyAuto(),
created_at=1234567890,
status="completed",
vector_store_id="vs-123",
)
assert "a" * 65 not in file_obj.attributes
assert "valid_key" in file_obj.attributes
# Test max 512 char string values are truncated
long_value_attrs = {"key": "x" * 600}
file_obj = VectorStoreFileObject(
id="file-125",
attributes=long_value_attrs,
chunking_strategy=VectorStoreChunkingStrategyAuto(),
created_at=1234567890,
status="completed",
vector_store_id="vs-123",
)
assert len(file_obj.attributes["key"]) == 512