Merge branch 'main' into vllm_health_check

commit c18b585d32
Authored by Sumit Jaiswal on 2025-06-05 18:09:36 +05:30, committed by GitHub
143 changed files with 9210 additions and 5347 deletions

@@ -5,6 +5,7 @@
# the root directory of this source tree.
# we want the mcp server to be authenticated OR not, depends
from collections.abc import Callable
from contextlib import contextmanager
# Unfortunately the toolgroup id must be tied to the tool names because the registry
@@ -13,15 +14,158 @@ from contextlib import contextmanager
MCP_TOOLGROUP_ID = "mcp::localmcp"
def default_tools():
"""Default tools for backward compatibility."""
from mcp import types
from mcp.server.fastmcp import Context
async def greet_everyone(
url: str, ctx: Context
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
return [types.TextContent(type="text", text="Hello, world!")]
async def get_boiling_point(liquid_name: str, celsius: bool = True) -> int:
"""
Returns the boiling point of a liquid in Celsius or Fahrenheit.
:param liquid_name: The name of the liquid
:param celsius: Whether to return the boiling point in Celsius
:return: The boiling point of the liquid in Celsius or Fahrenheit
"""
if liquid_name.lower() == "myawesomeliquid":
if celsius:
return -100
else:
return -212
else:
return -1
return {"greet_everyone": greet_everyone, "get_boiling_point": get_boiling_point}
def dependency_tools():
"""Tools with natural dependencies for multi-turn testing."""
from mcp import types
from mcp.server.fastmcp import Context
async def get_user_id(username: str, ctx: Context) -> str:
"""
Get the user ID for a given username. This ID is needed for other operations.
:param username: The username to look up
:return: The user ID for the username
"""
# Simple mapping for testing
user_mapping = {"alice": "user_12345", "bob": "user_67890", "charlie": "user_11111", "admin": "user_00000"}
return user_mapping.get(username.lower(), "user_99999")
async def get_user_permissions(user_id: str, ctx: Context) -> str:
"""
Get the permissions for a user ID. Requires a valid user ID from get_user_id.
:param user_id: The user ID to check permissions for
:return: The permissions for the user
"""
# Permission mapping based on user IDs
permission_mapping = {
"user_12345": "read,write", # alice
"user_67890": "read", # bob
"user_11111": "admin", # charlie
"user_00000": "superadmin", # admin
"user_99999": "none", # unknown users
}
return permission_mapping.get(user_id, "none")
async def check_file_access(user_id: str, filename: str, ctx: Context) -> str:
"""
Check if a user can access a specific file. Requires a valid user ID.
:param user_id: The user ID to check access for
:param filename: The filename to check access to
:return: Whether the user can access the file (yes/no)
"""
# Get permissions first
permission_mapping = {
"user_12345": "read,write", # alice
"user_67890": "read", # bob
"user_11111": "admin", # charlie
"user_00000": "superadmin", # admin
"user_99999": "none", # unknown users
}
permissions = permission_mapping.get(user_id, "none")
# Check file access based on permissions and filename
if permissions == "superadmin":
access = "yes"
elif permissions == "admin":
access = "yes" if not filename.startswith("secret_") else "no"
elif "write" in permissions:
access = "yes" if filename.endswith(".txt") else "no"
elif "read" in permissions:
access = "yes" if filename.endswith(".txt") or filename.endswith(".md") else "no"
else:
access = "no"
return access
async def get_experiment_id(experiment_name: str, ctx: Context) -> str:
"""
Get the experiment ID for a given experiment name. This ID is needed to get results.
:param experiment_name: The name of the experiment
:return: The experiment ID
"""
# Simple mapping for testing
experiment_mapping = {
"temperature_test": "exp_001",
"pressure_test": "exp_002",
"chemical_reaction": "exp_003",
"boiling_point": "exp_004",
}
exp_id = experiment_mapping.get(experiment_name.lower(), "exp_999")
return exp_id
async def get_experiment_results(experiment_id: str, ctx: Context) -> str:
"""
Get the results for an experiment ID. Requires a valid experiment ID from get_experiment_id.
:param experiment_id: The experiment ID to get results for
:return: The experiment results
"""
# Results mapping based on experiment IDs
results_mapping = {
"exp_001": "Temperature: 25°C, Status: Success",
"exp_002": "Pressure: 1.2 atm, Status: Success",
"exp_003": "Yield: 85%, Status: Complete",
"exp_004": "Boiling Point: 100°C, Status: Verified",
"exp_999": "No results found",
}
results = results_mapping.get(experiment_id, "Invalid experiment ID")
return results
return {
"get_user_id": get_user_id,
"get_user_permissions": get_user_permissions,
"check_file_access": check_file_access,
"get_experiment_id": get_experiment_id,
"get_experiment_results": get_experiment_results,
}
@contextmanager
def make_mcp_server(required_auth_token: str | None = None):
def make_mcp_server(required_auth_token: str | None = None, tools: dict[str, Callable] | None = None):
"""
Create an MCP server with the specified tools.
:param required_auth_token: Optional auth token required for access
:param tools: Dictionary of tool_name -> tool_function. If None, uses default tools.
"""
import threading
import time
import httpx
import uvicorn
from mcp import types
from mcp.server.fastmcp import Context, FastMCP
from mcp.server.fastmcp import FastMCP
from mcp.server.sse import SseServerTransport
from starlette.applications import Starlette
from starlette.responses import Response
@@ -29,35 +173,18 @@ def make_mcp_server(required_auth_token: str | None = None):
server = FastMCP("FastMCP Test Server", log_level="WARNING")
@server.tool()
async def greet_everyone(
url: str, ctx: Context
) -> list[types.TextContent | types.ImageContent | types.EmbeddedResource]:
return [types.TextContent(type="text", text="Hello, world!")]
tools = tools or default_tools()
@server.tool()
async def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
"""
Returns the boiling point of a liquid in Celcius or Fahrenheit.
:param liquid_name: The name of the liquid
:param celcius: Whether to return the boiling point in Celcius
:return: The boiling point of the liquid in Celcius or Fahrenheit
"""
if liquid_name.lower() == "polyjuice":
if celcius:
return -100
else:
return -212
else:
return -1
# Register all tools with the server
for tool_func in tools.values():
server.tool()(tool_func)
sse = SseServerTransport("/messages/")
async def handle_sse(request):
from starlette.exceptions import HTTPException
auth_header = request.headers.get("Authorization")
auth_header: str | None = request.headers.get("Authorization")
auth_token = None
if auth_header and auth_header.startswith("Bearer "):
auth_token = auth_header.split(" ")[1]
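The hunk above replaces the hard-coded tool definitions with an injectable tools dictionary. A minimal usage sketch follows; the import path tests.common.mcp and the value yielded by the context manager are assumptions for illustration, not something this diff confirms.

from tests.common.mcp import dependency_tools, make_mcp_server  # hypothetical import path

# Spin up a test server exposing the dependency-chain tools for multi-turn tests.
with make_mcp_server(required_auth_token="test-token", tools=dependency_tools()) as server_info:
    ...  # point the MCP client / tool runtime at whatever the context manager yields

# Omitting `tools` falls back to default_tools(), preserving the previous behavior.
with make_mcp_server() as server_info:
    ...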

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@@ -0,0 +1,51 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from io import BytesIO
def test_openai_client_basic_operations(openai_client):
"""Test basic file operations through OpenAI client."""
client = openai_client
test_content = b"files test content"
try:
# Upload file using OpenAI client
with BytesIO(test_content) as file_buffer:
file_buffer.name = "openai_test.txt"
uploaded_file = client.files.create(file=file_buffer, purpose="assistants")
# Verify basic response structure
assert uploaded_file.id.startswith("file-")
assert hasattr(uploaded_file, "filename")
# List files
files_list = client.files.list()
file_ids = [f.id for f in files_list.data]
assert uploaded_file.id in file_ids
# Retrieve file info
retrieved_file = client.files.retrieve(uploaded_file.id)
assert retrieved_file.id == uploaded_file.id
# Retrieve file content - OpenAI client returns httpx Response object
content_response = client.files.content(uploaded_file.id)
# The response is an httpx Response object with .content attribute containing bytes
content = content_response.content
assert content == test_content
# Delete file
delete_response = client.files.delete(uploaded_file.id)
assert delete_response.deleted is True
except Exception as e:
# Cleanup in case of failure
try:
client.files.delete(uploaded_file.id)
except Exception:
pass
raise e

@@ -121,7 +121,7 @@ class ToolGroupsImpl(Impl):
@pytest.mark.asyncio
async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry)
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register multiple models and verify listing
@@ -163,7 +163,7 @@ async def test_models_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_shields_routing_table(cached_disk_dist_registry):
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry)
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register multiple shields and verify listing
@@ -179,14 +179,14 @@ async def test_shields_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry)
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_providere": InferenceImpl()}, cached_disk_dist_registry)
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_providere",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
@@ -209,7 +209,7 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
async def test_datasets_routing_table(cached_disk_dist_registry):
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry)
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register multiple datasets and verify listing
@@ -235,7 +235,7 @@ async def test_datasets_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_scoring_functions_routing_table(cached_disk_dist_registry):
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry)
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register multiple scoring functions and verify listing
@@ -261,7 +261,7 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_benchmarks_routing_table(cached_disk_dist_registry):
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry)
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register multiple benchmarks and verify listing
@@ -279,7 +279,7 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_tool_groups_routing_table(cached_disk_dist_registry):
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry)
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {})
await table.initialize()
# Register multiple tool groups and verify listing

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

@@ -0,0 +1,334 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import OpenAIFilePurpose
from llama_stack.providers.inline.files.localfs import (
LocalfsFilesImpl,
LocalfsFilesImplConfig,
)
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
class MockUploadFile:
"""Mock UploadFile for testing file uploads."""
def __init__(self, content: bytes, filename: str, content_type: str = "text/plain"):
self.content = content
self.filename = filename
self.content_type = content_type
async def read(self):
return self.content
@pytest_asyncio.fixture
async def files_provider(tmp_path):
"""Create a files provider with temporary storage for testing."""
storage_dir = tmp_path / "files"
db_path = tmp_path / "files_metadata.db"
config = LocalfsFilesImplConfig(
storage_dir=storage_dir.as_posix(), metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix())
)
provider = LocalfsFilesImpl(config)
await provider.initialize()
yield provider
@pytest.fixture
def sample_text_file():
"""Sample text file for testing."""
content = b"Hello, this is a test file for the OpenAI Files API!"
return MockUploadFile(content, "test.txt", "text/plain")
@pytest.fixture
def sample_json_file():
"""Sample JSON file for testing."""
content = b'{"message": "Hello, World!", "type": "test"}'
return MockUploadFile(content, "data.json", "application/json")
@pytest.fixture
def large_file():
"""Large file for testing file size handling."""
content = b"x" * 1024 * 1024 # 1MB file
return MockUploadFile(content, "large_file.bin", "application/octet-stream")
class TestOpenAIFilesAPI:
"""Test suite for OpenAI Files API endpoints."""
@pytest.mark.asyncio
async def test_upload_file_success(self, files_provider, sample_text_file):
"""Test successful file upload."""
# Upload file
result = await files_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
# Verify response
assert result.id.startswith("file-")
assert result.filename == "test.txt"
assert result.purpose == OpenAIFilePurpose.ASSISTANTS
assert result.bytes == len(sample_text_file.content)
assert result.created_at > 0
assert result.expires_at > result.created_at
@pytest.mark.asyncio
async def test_upload_different_purposes(self, files_provider, sample_text_file):
"""Test uploading files with different purposes."""
purposes = list(OpenAIFilePurpose)
uploaded_files = []
for purpose in purposes:
result = await files_provider.openai_upload_file(file=sample_text_file, purpose=purpose)
uploaded_files.append(result)
assert result.purpose == purpose
@pytest.mark.asyncio
async def test_upload_different_file_types(self, files_provider, sample_text_file, sample_json_file, large_file):
"""Test uploading different types and sizes of files."""
files_to_test = [
(sample_text_file, "test.txt"),
(sample_json_file, "data.json"),
(large_file, "large_file.bin"),
]
for file_obj, expected_filename in files_to_test:
result = await files_provider.openai_upload_file(file=file_obj, purpose=OpenAIFilePurpose.ASSISTANTS)
assert result.filename == expected_filename
assert result.bytes == len(file_obj.content)
@pytest.mark.asyncio
async def test_list_files_empty(self, files_provider):
"""Test listing files when no files exist."""
result = await files_provider.openai_list_files()
assert result.data == []
assert result.has_more is False
assert result.first_id == ""
assert result.last_id == ""
@pytest.mark.asyncio
async def test_list_files_with_content(self, files_provider, sample_text_file, sample_json_file):
"""Test listing files when files exist."""
# Upload multiple files
file1 = await files_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
file2 = await files_provider.openai_upload_file(file=sample_json_file, purpose=OpenAIFilePurpose.ASSISTANTS)
# List files
result = await files_provider.openai_list_files()
assert len(result.data) == 2
file_ids = [f.id for f in result.data]
assert file1.id in file_ids
assert file2.id in file_ids
@pytest.mark.asyncio
async def test_list_files_with_purpose_filter(self, files_provider, sample_text_file):
"""Test listing files with purpose filtering."""
# Upload file with specific purpose
uploaded_file = await files_provider.openai_upload_file(
file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS
)
# List files with matching purpose
result = await files_provider.openai_list_files(purpose=OpenAIFilePurpose.ASSISTANTS)
assert len(result.data) == 1
assert result.data[0].id == uploaded_file.id
assert result.data[0].purpose == OpenAIFilePurpose.ASSISTANTS
@pytest.mark.asyncio
async def test_list_files_with_limit(self, files_provider, sample_text_file):
"""Test listing files with limit parameter."""
# Upload multiple files
for _ in range(5):
await files_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
# List with limit
result = await files_provider.openai_list_files(limit=3)
assert len(result.data) == 3
@pytest.mark.asyncio
async def test_list_files_with_order(self, files_provider, sample_text_file):
"""Test listing files with different order."""
# Upload multiple files
files = []
for _ in range(3):
file = await files_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
files.append(file)
# Test descending order (default)
result_desc = await files_provider.openai_list_files(order=Order.desc)
assert len(result_desc.data) == 3
# Most recent should be first
assert result_desc.data[0].created_at >= result_desc.data[1].created_at >= result_desc.data[2].created_at
# Test ascending order
result_asc = await files_provider.openai_list_files(order=Order.asc)
assert len(result_asc.data) == 3
# Oldest should be first
assert result_asc.data[0].created_at <= result_asc.data[1].created_at <= result_asc.data[2].created_at
@pytest.mark.asyncio
async def test_retrieve_file_success(self, files_provider, sample_text_file):
"""Test successful file retrieval."""
# Upload file
uploaded_file = await files_provider.openai_upload_file(
file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS
)
# Retrieve file
retrieved_file = await files_provider.openai_retrieve_file(uploaded_file.id)
# Verify response
assert retrieved_file.id == uploaded_file.id
assert retrieved_file.filename == uploaded_file.filename
assert retrieved_file.purpose == uploaded_file.purpose
assert retrieved_file.bytes == uploaded_file.bytes
assert retrieved_file.created_at == uploaded_file.created_at
assert retrieved_file.expires_at == uploaded_file.expires_at
@pytest.mark.asyncio
async def test_retrieve_file_not_found(self, files_provider):
"""Test retrieving a non-existent file."""
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
await files_provider.openai_retrieve_file("file-nonexistent")
@pytest.mark.asyncio
async def test_retrieve_file_content_success(self, files_provider, sample_text_file):
"""Test successful file content retrieval."""
# Upload file
uploaded_file = await files_provider.openai_upload_file(
file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS
)
# Retrieve file content
content = await files_provider.openai_retrieve_file_content(uploaded_file.id)
# Verify content
assert content.body == sample_text_file.content
@pytest.mark.asyncio
async def test_retrieve_file_content_not_found(self, files_provider):
"""Test retrieving content of a non-existent file."""
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
await files_provider.openai_retrieve_file_content("file-nonexistent")
@pytest.mark.asyncio
async def test_delete_file_success(self, files_provider, sample_text_file):
"""Test successful file deletion."""
# Upload file
uploaded_file = await files_provider.openai_upload_file(
file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS
)
# Verify file exists
await files_provider.openai_retrieve_file(uploaded_file.id)
# Delete file
delete_response = await files_provider.openai_delete_file(uploaded_file.id)
# Verify delete response
assert delete_response.id == uploaded_file.id
assert delete_response.deleted is True
# Verify file no longer exists
with pytest.raises(ValueError, match=f"File with id {uploaded_file.id} not found"):
await files_provider.openai_retrieve_file(uploaded_file.id)
@pytest.mark.asyncio
async def test_delete_file_not_found(self, files_provider):
"""Test deleting a non-existent file."""
with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
await files_provider.openai_delete_file("file-nonexistent")
@pytest.mark.asyncio
async def test_file_persistence_across_operations(self, files_provider, sample_text_file):
"""Test that files persist correctly across multiple operations."""
# Upload file
uploaded_file = await files_provider.openai_upload_file(
file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS
)
# Verify it appears in listing
files_list = await files_provider.openai_list_files()
assert len(files_list.data) == 1
assert files_list.data[0].id == uploaded_file.id
# Retrieve file info
retrieved_file = await files_provider.openai_retrieve_file(uploaded_file.id)
assert retrieved_file.id == uploaded_file.id
# Retrieve file content
content = await files_provider.openai_retrieve_file_content(uploaded_file.id)
assert content.body == sample_text_file.content
# Delete file
await files_provider.openai_delete_file(uploaded_file.id)
# Verify it's gone from listing
files_list = await files_provider.openai_list_files()
assert len(files_list.data) == 0
@pytest.mark.asyncio
async def test_multiple_files_operations(self, files_provider, sample_text_file, sample_json_file):
"""Test operations with multiple files."""
# Upload multiple files
file1 = await files_provider.openai_upload_file(file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS)
file2 = await files_provider.openai_upload_file(file=sample_json_file, purpose=OpenAIFilePurpose.ASSISTANTS)
# Verify both exist
files_list = await files_provider.openai_list_files()
assert len(files_list.data) == 2
# Delete one file
await files_provider.openai_delete_file(file1.id)
# Verify only one remains
files_list = await files_provider.openai_list_files()
assert len(files_list.data) == 1
assert files_list.data[0].id == file2.id
# Verify the remaining file is still accessible
content = await files_provider.openai_retrieve_file_content(file2.id)
assert content.body == sample_json_file.content
@pytest.mark.asyncio
async def test_file_id_uniqueness(self, files_provider, sample_text_file):
"""Test that each uploaded file gets a unique ID."""
file_ids = set()
# Upload same file multiple times
for _ in range(10):
uploaded_file = await files_provider.openai_upload_file(
file=sample_text_file, purpose=OpenAIFilePurpose.ASSISTANTS
)
assert uploaded_file.id not in file_ids, f"Duplicate file ID: {uploaded_file.id}"
file_ids.add(uploaded_file.id)
assert uploaded_file.id.startswith("file-")
@pytest.mark.asyncio
async def test_file_no_filename_handling(self, files_provider):
"""Test handling files with no filename."""
file_without_name = MockUploadFile(b"content", None) # No filename
uploaded_file = await files_provider.openai_upload_file(
file=file_without_name, purpose=OpenAIFilePurpose.ASSISTANTS
)
assert uploaded_file.filename == "uploaded_file" # Default filename
@pytest.mark.asyncio
async def test_after_pagination_not_implemented(self, files_provider):
"""Test that 'after' pagination raises NotImplementedError."""
with pytest.raises(NotImplementedError, match="After pagination not yet implemented"):
await files_provider.openai_list_files(after="file-some-id")

@@ -0,0 +1,177 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import time
from pathlib import Path
from unittest.mock import patch
import pytest
from tiktoken.load import load_tiktoken_bpe
from llama_stack.models.llama.tokenizer_utils import load_bpe_file
@pytest.fixture
def test_bpe_content():
"""Sample BPE file content for testing."""
return """wA== 0
wQ== 1
9Q== 2
9g== 3
9w== 4
+A== 5
+Q== 6
+g== 7
+w== 8
/A== 9
/Q== 10
/g== 11
/w== 12
AA== 13
AQ== 14"""
@pytest.fixture
def test_bpe_file(tmp_path, test_bpe_content):
"""Create a temporary BPE file for testing."""
bpe_file = tmp_path / "test_tokenizer.model"
bpe_file.write_text(test_bpe_content, encoding="utf-8")
return bpe_file
@pytest.fixture
def llama3_model_path():
"""Path to Llama3 tokenizer model."""
return Path(__file__).parent / "../../../../llama_stack/models/llama/llama3/tokenizer.model"
@pytest.fixture
def llama4_model_path():
"""Path to Llama4 tokenizer model."""
return Path(__file__).parent / "../../../../llama_stack/models/llama/llama4/tokenizer.model"
def test_load_bpe_file_basic_functionality(test_bpe_file):
"""Test that load_bpe_file correctly parses BPE files."""
result = load_bpe_file(test_bpe_file)
for key, value in result.items():
assert isinstance(key, bytes)
assert isinstance(value, int)
assert len(result) == 15
expected_first_token = base64.b64decode("wA==")
assert expected_first_token in result
assert result[expected_first_token] == 0
def test_load_bpe_file_vs_tiktoken_with_real_model(llama3_model_path):
"""Test that our implementation produces identical results to tiktoken on real model files."""
if not llama3_model_path.exists():
pytest.skip("Llama3 tokenizer model not found")
our_result = load_bpe_file(llama3_model_path)
tiktoken_result = load_tiktoken_bpe(llama3_model_path.as_posix())
# Compare results from our implementation and tiktoken
assert len(our_result) == len(tiktoken_result)
assert our_result == tiktoken_result
assert len(our_result) > 100000
ranks = list(our_result.values())
assert len(ranks) == len(set(ranks))
def test_load_bpe_file_vs_tiktoken_with_llama4_model(llama4_model_path):
"""Test that our implementation produces identical results to tiktoken on Llama4 model."""
if not llama4_model_path.exists():
pytest.skip("Llama4 tokenizer model not found")
our_result = load_bpe_file(llama4_model_path)
tiktoken_result = load_tiktoken_bpe(llama4_model_path.as_posix())
# Compare results from our implementation and tiktoken
assert len(our_result) == len(tiktoken_result)
assert our_result == tiktoken_result
assert len(our_result) > 100000
ranks = list(our_result.values())
assert len(ranks) == len(set(ranks))
def test_load_bpe_file_malformed_lines(tmp_path):
"""Test that load_bpe_file handles malformed lines gracefully."""
malformed_content = """wA== 0
invalid_line_without_rank
wQ== 1
invalid_base64!!! 2
9Q== 2"""
test_file = tmp_path / "malformed.model"
test_file.write_text(malformed_content, encoding="utf-8")
with patch("llama_stack.models.llama.tokenizer_utils.logger") as mock_logger:
result = load_bpe_file(test_file)
# Should have 3 valid entries (skipping malformed ones)
assert len(result) == 3
# Should have logged warnings for malformed lines
assert mock_logger.warning.called
assert mock_logger.warning.call_count > 0
def test_load_bpe_file_nonexistent_file():
"""Test that load_bpe_file raises appropriate error for nonexistent files."""
with pytest.raises(FileNotFoundError):
load_bpe_file("/nonexistent/path/to/file.model")
def test_tokenizer_integration():
"""Test that our load_bpe_file works correctly when used in actual tokenizers."""
try:
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
tokenizer = Llama3Tokenizer.get_instance()
# Test basic functionality
test_text = "Hello, world! This is a test."
tokens = tokenizer.encode(test_text, bos=False, eos=False)
decoded = tokenizer.decode(tokens)
assert test_text == decoded
assert isinstance(tokens, list)
assert all(isinstance(token, int) for token in tokens)
except Exception as e:
pytest.skip(f"Llama3 tokenizer not available: {e}")
def test_performance_comparison(llama3_model_path):
"""Test that our implementation has reasonable performance compared to tiktoken."""
if not llama3_model_path.exists():
pytest.skip("Llama3 tokenizer model not found")
# Time our implementation
start_time = time.time()
our_result = load_bpe_file(llama3_model_path)
our_time = time.time() - start_time
# Time tiktoken implementation
start_time = time.time()
tiktoken_result = load_tiktoken_bpe(llama3_model_path.as_posix())
tiktoken_time = time.time() - start_time
# Verify results are identical
assert our_result == tiktoken_result
# Our implementation should be reasonably fast (within 10x of tiktoken)
# This is a loose bound since we're optimizing for correctness, not speed
assert our_time < tiktoken_time * 10, f"Our implementation took {our_time:.3f}s vs tiktoken's {tiktoken_time:.3f}s"
print(f"Performance comparison - Our: {our_time:.3f}s, Tiktoken: {tiktoken_time:.3f}s")

@@ -59,6 +59,7 @@ async def agents_impl(config, mock_apis):
mock_apis["safety_api"],
mock_apis["tool_runtime_api"],
mock_apis["tool_groups_api"],
{},
)
await impl.initialize()
yield impl

@@ -25,11 +25,17 @@ from llama_stack.apis.agents.openai_responses import (
OpenAIResponseObjectWithInput,
OpenAIResponseOutputMessageContentOutputText,
OpenAIResponseOutputMessageWebSearchToolCall,
OpenAIResponseText,
OpenAIResponseTextFormat,
)
from llama_stack.apis.inference.inference import (
OpenAIAssistantMessageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIDeveloperMessageParam,
OpenAIJSONSchema,
OpenAIResponseFormatJSONObject,
OpenAIResponseFormatJSONSchema,
OpenAIResponseFormatText,
OpenAIUserMessageParam,
)
from llama_stack.apis.tools.tools import Tool, ToolGroups, ToolInvocationResult, ToolParameter, ToolRuntime
@@ -96,6 +102,7 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
mock_inference_api.openai_chat_completion.assert_called_once_with(
model=model,
messages=[OpenAIUserMessageParam(role="user", content="What is the capital of Ireland?", name=None)],
response_format=OpenAIResponseFormatText(),
tools=None,
stream=False,
temperature=0.1,
@@ -224,16 +231,16 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
],
)
# Verify
# Check that we got the content from our mocked tool execution result
chunks = [chunk async for chunk in result]
assert len(chunks) == 2 # Should have response.created and response.completed
# Verify inference API was called correctly (after iterating over result)
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["tools"] is not None
assert first_call.kwargs["temperature"] == 0.1
# Check that we got the content from our mocked tool execution result
chunks = [chunk async for chunk in result]
assert len(chunks) == 2 # Should have response.created and response.completed
# Check response.created event (should have empty output)
assert chunks[0].type == "response.created"
assert len(chunks[0].response.output) == 0
@@ -320,6 +327,7 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_respo
model="fake_model",
output=[response_output_message],
status="completed",
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
input=[input_item_message],
)
mock_responses_store.get_response_object.return_value = previous_response
@@ -362,6 +370,7 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
model="fake_model",
output=[output_web_search, output_message],
status="completed",
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
input=[input_item_message],
)
mock_responses_store.get_response_object.return_value = response
@@ -483,6 +492,7 @@ async def test_create_openai_response_with_instructions_and_previous_response(
model="fake_model",
output=[response_output_message],
status="completed",
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
input=[input_item_message],
)
mock_responses_store.get_response_object.return_value = response
@@ -576,6 +586,7 @@ async def test_responses_store_list_input_items_logic():
object="response",
status="completed",
output=[],
text=OpenAIResponseText(format=(OpenAIResponseTextFormat(type="text"))),
input=input_items,
)
@@ -644,6 +655,7 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
created_at=1234567890,
model="meta-llama/Llama-3.1-8B-Instruct",
status="completed",
text=OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")),
input=[
OpenAIResponseMessage(
id="msg-prev-user", role="user", content=[OpenAIResponseInputMessageContentText(text="What is 2+2?")]
@@ -694,3 +706,61 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
# Verify the response itself is correct
assert result.model == model
assert result.status == "completed"
@pytest.mark.asyncio
@pytest.mark.parametrize(
"text_format, response_format",
[
(OpenAIResponseText(format=OpenAIResponseTextFormat(type="text")), OpenAIResponseFormatText()),
(
OpenAIResponseText(format=OpenAIResponseTextFormat(name="Test", schema={"foo": "bar"}, type="json_schema")),
OpenAIResponseFormatJSONSchema(json_schema=OpenAIJSONSchema(name="Test", schema={"foo": "bar"})),
),
(OpenAIResponseText(format=OpenAIResponseTextFormat(type="json_object")), OpenAIResponseFormatJSONObject()),
# ensure text param with no format specified defaults to text
(OpenAIResponseText(format=None), OpenAIResponseFormatText()),
# ensure text param of None defaults to text
(None, OpenAIResponseFormatText()),
],
)
async def test_create_openai_response_with_text_format(
openai_responses_impl, mock_inference_api, text_format, response_format
):
"""Test creating Responses with text formats."""
# Setup
input_text = "How hot it is in San Francisco today?"
model = "meta-llama/Llama-3.1-8B-Instruct"
# Load the chat completion fixture
mock_chat_completion = load_chat_completion_fixture("simple_chat_completion.yaml")
mock_inference_api.openai_chat_completion.return_value = mock_chat_completion
# Execute
_result = await openai_responses_impl.create_openai_response(
input=input_text,
model=model,
text=text_format,
)
# Verify
first_call = mock_inference_api.openai_chat_completion.call_args_list[0]
assert first_call.kwargs["messages"][0].content == input_text
assert first_call.kwargs["response_format"] is not None
assert first_call.kwargs["response_format"] == response_format
@pytest.mark.asyncio
async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with an invalid text format."""
# Setup
input_text = "How hot it is in San Francisco today?"
model = "meta-llama/Llama-3.1-8B-Instruct"
# Execute
with pytest.raises(ValueError):
_result = await openai_responses_impl.create_openai_response(
input=input_text,
model=model,
text=OpenAIResponseText(format={"type": "invalid"}),
)

@@ -12,24 +12,24 @@ import pytest
from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.datatypes import User
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
@pytest.fixture
async def test_setup(sqlite_kvstore):
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore)
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
yield agent_persistence
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_session_creation_with_access_attributes(mock_get_auth_attributes, test_setup):
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Set creator's attributes for the session
creator_attributes = {"roles": ["researcher"], "teams": ["ai-team"]}
mock_get_auth_attributes.return_value = creator_attributes
mock_get_authenticated_user.return_value = User("test_user", creator_attributes)
# Create a session
session_id = await agent_persistence.create_session("Test Session")
@@ -37,14 +37,15 @@ async def test_session_creation_with_access_attributes(mock_get_auth_attributes,
# Get the session and verify access attributes were set
session_info = await agent_persistence.get_session_info(session_id)
assert session_info is not None
assert session_info.access_attributes is not None
assert session_info.access_attributes.roles == ["researcher"]
assert session_info.access_attributes.teams == ["ai-team"]
assert session_info.owner is not None
assert session_info.owner.attributes is not None
assert session_info.owner.attributes["roles"] == ["researcher"]
assert session_info.owner.attributes["teams"] == ["ai-team"]
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_session_access_control(mock_get_auth_attributes, test_setup):
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Create a session with specific access attributes
@@ -53,8 +54,9 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
session_id=session_id,
session_name="Restricted Session",
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]),
owner=User("someone", {"roles": ["admin"], "teams": ["security-team"]}),
turns=[],
identifier="Restricted Session",
)
await agent_persistence.kvstore.set(
@@ -63,20 +65,22 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup):
)
# User with matching attributes can access
mock_get_auth_attributes.return_value = {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
mock_get_authenticated_user.return_value = User(
"testuser", {"roles": ["admin", "user"], "teams": ["security-team", "other-team"]}
)
retrieved_session = await agent_persistence.get_session_info(session_id)
assert retrieved_session is not None
assert retrieved_session.session_id == session_id
# User without matching attributes cannot access
mock_get_auth_attributes.return_value = {"roles": ["user"], "teams": ["other-team"]}
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"], "teams": ["other-team"]})
retrieved_session = await agent_persistence.get_session_info(session_id)
assert retrieved_session is None
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_turn_access_control(mock_get_auth_attributes, test_setup):
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_turn_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Create a session with restricted access
@@ -85,8 +89,9 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
session_id=session_id,
session_name="Restricted Session",
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]),
owner=User("someone", {"roles": ["admin"]}),
turns=[],
identifier="Restricted Session",
)
await agent_persistence.kvstore.set(
@@ -109,7 +114,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
)
# Admin can add turn
mock_get_auth_attributes.return_value = {"roles": ["admin"]}
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
await agent_persistence.add_turn_to_session(session_id, turn)
# Admin can get turn
@@ -118,7 +123,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
assert retrieved_turn.turn_id == turn_id
# Regular user cannot get turn
mock_get_auth_attributes.return_value = {"roles": ["user"]}
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
with pytest.raises(ValueError):
await agent_persistence.get_session_turn(session_id, turn_id)
@@ -128,8 +133,8 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup):
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_auth_attributes")
async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes, test_setup):
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup
# Create a session with restricted access
@@ -138,8 +143,9 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
session_id=session_id,
session_name="Restricted Session",
started_at=datetime.now(),
access_attributes=AccessAttributes(roles=["admin"]),
owner=User("someone", {"roles": ["admin"]}),
turns=[],
identifier="Restricted Session",
)
await agent_persistence.kvstore.set(
@@ -150,7 +156,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
turn_id = str(uuid.uuid4())
# Admin user can set inference iterations
mock_get_auth_attributes.return_value = {"roles": ["admin"]}
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["admin"]})
await agent_persistence.set_num_infer_iters_in_turn(session_id, turn_id, 5)
# Admin user can get inference iterations
@@ -158,7 +164,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes
assert infer_iters == 5
# Regular user cannot get inference iterations
mock_get_auth_attributes.return_value = {"roles": ["user"]}
mock_get_authenticated_user.return_value = User("testuser", {"roles": ["user"]})
infer_iters = await agent_persistence.get_num_infer_iters_in_turn(session_id, turn_id)
assert infer_iters is None

@@ -70,9 +70,12 @@ class MockInferenceAdapterWithSleep:
# ruff: noqa: N802
def do_POST(self):
time.sleep(sleep_time)
response_body = json.dumps(response).encode("utf-8")
self.send_response(code=200)
self.send_header("Content-Type", "application/json")
self.send_header("Content-Length", len(response_body))
self.end_headers()
self.wfile.write(json.dumps(response).encode("utf-8"))
self.wfile.write(response_body)
self.request_handler = DelayedRequestHandler

@@ -8,19 +8,18 @@
import pytest
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import ModelWithACL
from llama_stack.distribution.server.auth_providers import AccessAttributes
from llama_stack.distribution.datatypes import ModelWithOwner, User
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry
@pytest.mark.asyncio
async def test_registry_cache_with_acl(cached_disk_dist_registry):
model = ModelWithACL(
model = ModelWithOwner(
identifier="model-acl",
provider_id="test-provider",
provider_resource_id="model-acl-resource",
model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["admin"], teams=["ai-team"]),
owner=User("testuser", {"roles": ["admin"], "teams": ["ai-team"]}),
)
success = await cached_disk_dist_registry.register(model)
@@ -29,22 +28,14 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
cached_model = cached_disk_dist_registry.get_cached("model", "model-acl")
assert cached_model is not None
assert cached_model.identifier == "model-acl"
assert cached_model.access_attributes.roles == ["admin"]
assert cached_model.access_attributes.teams == ["ai-team"]
assert cached_model.owner.principal == "testuser"
assert cached_model.owner.attributes["roles"] == ["admin"]
assert cached_model.owner.attributes["teams"] == ["ai-team"]
fetched_model = await cached_disk_dist_registry.get("model", "model-acl")
assert fetched_model is not None
assert fetched_model.identifier == "model-acl"
assert fetched_model.access_attributes.roles == ["admin"]
model.access_attributes = AccessAttributes(roles=["admin", "user"], projects=["project-x"])
await cached_disk_dist_registry.update(model)
updated_cached = cached_disk_dist_registry.get_cached("model", "model-acl")
assert updated_cached is not None
assert updated_cached.access_attributes.roles == ["admin", "user"]
assert updated_cached.access_attributes.projects == ["project-x"]
assert updated_cached.access_attributes.teams is None
assert fetched_model.owner.attributes["roles"] == ["admin"]
new_registry = CachedDiskDistributionRegistry(cached_disk_dist_registry.kvstore)
await new_registry.initialize()
@@ -52,35 +43,32 @@ async def test_registry_cache_with_acl(cached_disk_dist_registry):
new_model = await new_registry.get("model", "model-acl")
assert new_model is not None
assert new_model.identifier == "model-acl"
assert new_model.access_attributes.roles == ["admin", "user"]
assert new_model.access_attributes.projects == ["project-x"]
assert new_model.access_attributes.teams is None
assert new_model.owner.principal == "testuser"
assert new_model.owner.attributes["roles"] == ["admin"]
assert new_model.owner.attributes["teams"] == ["ai-team"]
@pytest.mark.asyncio
async def test_registry_empty_acl(cached_disk_dist_registry):
model = ModelWithACL(
model = ModelWithOwner(
identifier="model-empty-acl",
provider_id="test-provider",
provider_resource_id="model-resource",
model_type=ModelType.llm,
access_attributes=AccessAttributes(),
owner=User("testuser", None),
)
await cached_disk_dist_registry.register(model)
cached_model = cached_disk_dist_registry.get_cached("model", "model-empty-acl")
assert cached_model is not None
assert cached_model.access_attributes is not None
assert cached_model.access_attributes.roles is None
assert cached_model.access_attributes.teams is None
assert cached_model.access_attributes.projects is None
assert cached_model.access_attributes.namespaces is None
assert cached_model.owner is not None
assert cached_model.owner.attributes is None
all_models = await cached_disk_dist_registry.get_all()
assert len(all_models) == 1
model = ModelWithACL(
model = ModelWithOwner(
identifier="model-no-acl",
provider_id="test-provider",
provider_resource_id="model-resource-2",
@@ -91,7 +79,7 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
cached_model = cached_disk_dist_registry.get_cached("model", "model-no-acl")
assert cached_model is not None
assert cached_model.access_attributes is None
assert cached_model.owner is None
all_models = await cached_disk_dist_registry.get_all()
assert len(all_models) == 2
@@ -99,19 +87,19 @@ async def test_registry_empty_acl(cached_disk_dist_registry):
@pytest.mark.asyncio
async def test_registry_serialization(cached_disk_dist_registry):
attributes = AccessAttributes(
roles=["admin", "researcher"],
teams=["ai-team", "ml-team"],
projects=["project-a", "project-b"],
namespaces=["prod", "staging"],
)
attributes = {
"roles": ["admin", "researcher"],
"teams": ["ai-team", "ml-team"],
"projects": ["project-a", "project-b"],
"namespaces": ["prod", "staging"],
}
model = ModelWithACL(
model = ModelWithOwner(
identifier="model-serialize",
provider_id="test-provider",
provider_resource_id="model-resource",
model_type=ModelType.llm,
access_attributes=attributes,
owner=User("bob", attributes),
)
await cached_disk_dist_registry.register(model)
@@ -122,7 +110,7 @@ async def test_registry_serialization(cached_disk_dist_registry):
loaded_model = await new_registry.get("model", "model-serialize")
assert loaded_model is not None
assert loaded_model.access_attributes.roles == ["admin", "researcher"]
assert loaded_model.access_attributes.teams == ["ai-team", "ml-team"]
assert loaded_model.access_attributes.projects == ["project-a", "project-b"]
assert loaded_model.access_attributes.namespaces == ["prod", "staging"]
assert loaded_model.owner.attributes["roles"] == ["admin", "researcher"]
assert loaded_model.owner.attributes["teams"] == ["ai-team", "ml-team"]
assert loaded_model.owner.attributes["projects"] == ["project-a", "project-b"]
assert loaded_model.owner.attributes["namespaces"] == ["prod", "staging"]

@@ -7,10 +7,13 @@
from unittest.mock import MagicMock, Mock, patch
import pytest
import yaml
from pydantic import TypeAdapter, ValidationError
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.distribution.datatypes import AccessAttributes, ModelWithACL
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.distribution.datatypes import AccessRule, ModelWithOwner, User
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
@@ -32,39 +35,40 @@ async def test_setup(cached_disk_dist_registry):
routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference},
dist_registry=cached_disk_dist_registry,
policy={},
)
yield cached_disk_dist_registry, routing_table
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_with_cache(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model_public = ModelWithACL(
model_public = ModelWithOwner(
identifier="model-public",
provider_id="test_provider",
provider_resource_id="model-public",
model_type=ModelType.llm,
)
model_admin_only = ModelWithACL(
model_admin_only = ModelWithOwner(
identifier="model-admin",
provider_id="test_provider",
provider_resource_id="model-admin",
model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["admin"]),
owner=User("testuser", {"roles": ["admin"]}),
)
model_data_scientist = ModelWithACL(
model_data_scientist = ModelWithOwner(
identifier="model-data-scientist",
provider_id="test_provider",
provider_resource_id="model-data-scientist",
model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["data-scientist", "researcher"], teams=["ml-team"]),
owner=User("testuser", {"roles": ["data-scientist", "researcher"], "teams": ["ml-team"]}),
)
await registry.register(model_public)
await registry.register(model_admin_only)
await registry.register(model_data_scientist)
mock_get_auth_attributes.return_value = {"roles": ["admin"], "teams": ["management"]}
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["admin"], "teams": ["management"]})
all_models = await routing_table.list_models()
assert len(all_models.data) == 2
@@ -75,7 +79,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
with pytest.raises(ValueError):
await routing_table.get_model("model-data-scientist")
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["other-team"]}
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["other-team"]})
all_models = await routing_table.list_models()
assert len(all_models.data) == 1
assert all_models.data[0].identifier == "model-public"
@@ -86,7 +90,7 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
with pytest.raises(ValueError):
await routing_table.get_model("model-data-scientist")
mock_get_auth_attributes.return_value = {"roles": ["data-scientist"], "teams": ["ml-team"]}
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["data-scientist"], "teams": ["ml-team"]})
all_models = await routing_table.list_models()
assert len(all_models.data) == 2
model_ids = [m.identifier for m in all_models.data]
@@ -102,50 +106,62 @@ async def test_access_control_with_cache(mock_get_auth_attributes, test_setup):
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_access_control_and_updates(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_and_updates(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model_public = ModelWithACL(
model_public = ModelWithOwner(
identifier="model-updates",
provider_id="test_provider",
provider_resource_id="model-updates",
model_type=ModelType.llm,
)
await registry.register(model_public)
mock_get_auth_attributes.return_value = {
"roles": ["user"],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": ["user"],
},
)
model = await routing_table.get_model("model-updates")
assert model.identifier == "model-updates"
model_public.access_attributes = AccessAttributes(roles=["admin"])
model_public.owner = User("testuser", {"roles": ["admin"]})
await registry.update(model_public)
mock_get_auth_attributes.return_value = {
"roles": ["user"],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": ["user"],
},
)
with pytest.raises(ValueError):
await routing_table.get_model("model-updates")
mock_get_auth_attributes.return_value = {
"roles": ["admin"],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": ["admin"],
},
)
model = await routing_table.get_model("model-updates")
assert model.identifier == "model-updates"
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_access_control_empty_attributes(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_control_empty_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model = ModelWithACL(
model = ModelWithOwner(
identifier="model-empty-attrs",
provider_id="test_provider",
provider_resource_id="model-empty-attrs",
model_type=ModelType.llm,
access_attributes=AccessAttributes(),
owner=User("testuser", {}),
)
await registry.register(model)
mock_get_auth_attributes.return_value = {
"roles": [],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": [],
},
)
result = await routing_table.get_model("model-empty-attrs")
assert result.identifier == "model-empty-attrs"
all_models = await routing_table.list_models()
@@ -154,25 +170,25 @@ async def test_access_control_empty_attributes(mock_get_auth_attributes, test_se
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_no_user_attributes(mock_get_authenticated_user, test_setup):
registry, routing_table = test_setup
model_public = ModelWithACL(
model_public = ModelWithOwner(
identifier="model-public-2",
provider_id="test_provider",
provider_resource_id="model-public-2",
model_type=ModelType.llm,
)
model_restricted = ModelWithACL(
model_restricted = ModelWithOwner(
identifier="model-restricted",
provider_id="test_provider",
provider_resource_id="model-restricted",
model_type=ModelType.llm,
access_attributes=AccessAttributes(roles=["admin"]),
owner=User("testuser", {"roles": ["admin"]}),
)
await registry.register(model_public)
await registry.register(model_restricted)
mock_get_auth_attributes.return_value = None
mock_get_authenticated_user.return_value = User("test-user", None)
model = await routing_table.get_model("model-public-2")
assert model.identifier == "model-public-2"
@@ -185,17 +201,17 @@ async def test_no_user_attributes(mock_get_auth_attributes, test_setup):
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_auth_attributes")
async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup):
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_automatic_access_attributes(mock_get_authenticated_user, test_setup):
"""Test that newly created resources inherit access attributes from their creator."""
registry, routing_table = test_setup
# Set creator's attributes
creator_attributes = {"roles": ["data-scientist"], "teams": ["ml-team"], "projects": ["llama-3"]}
mock_get_auth_attributes.return_value = creator_attributes
mock_get_authenticated_user.return_value = User("test-user", creator_attributes)
# Create model without explicit access attributes
model = ModelWithACL(
model = ModelWithOwner(
identifier="auto-access-model",
provider_id="test_provider",
provider_resource_id="auto-access-model",
@ -205,21 +221,346 @@ async def test_automatic_access_attributes(mock_get_auth_attributes, test_setup)
# Verify the model got creator's attributes
registered_model = await routing_table.get_model("auto-access-model")
assert registered_model.access_attributes is not None
assert registered_model.access_attributes.roles == ["data-scientist"]
assert registered_model.access_attributes.teams == ["ml-team"]
assert registered_model.access_attributes.projects == ["llama-3"]
assert registered_model.owner is not None
assert registered_model.owner.attributes is not None
assert registered_model.owner.attributes["roles"] == ["data-scientist"]
assert registered_model.owner.attributes["teams"] == ["ml-team"]
assert registered_model.owner.attributes["projects"] == ["llama-3"]
# Verify another user without matching attributes can't access it
mock_get_auth_attributes.return_value = {"roles": ["engineer"], "teams": ["infra-team"]}
mock_get_authenticated_user.return_value = User("test-user", {"roles": ["engineer"], "teams": ["infra-team"]})
with pytest.raises(ValueError):
await routing_table.get_model("auto-access-model")
# But a user with matching attributes can
mock_get_auth_attributes.return_value = {
"roles": ["data-scientist", "engineer"],
"teams": ["ml-team", "platform-team"],
"projects": ["llama-3"],
}
mock_get_authenticated_user.return_value = User(
"test-user",
{
"roles": ["data-scientist", "engineer"],
"teams": ["ml-team", "platform-team"],
"projects": ["llama-3"],
},
)
model = await routing_table.get_model("auto-access-model")
assert model.identifier == "auto-access-model"
@pytest.fixture
async def test_setup_with_access_policy(cached_disk_dist_registry):
mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock()
mock_inference.__provider_spec__.api = Api.inference
mock_inference.register_model = AsyncMock(side_effect=_return_model)
mock_inference.unregister_model = AsyncMock(side_effect=_return_model)
config = """
- permit:
principal: user-1
actions: [create, read, delete]
description: user-1 has full access to all models
- permit:
principal: user-2
actions: [read]
resource: model::model-1
description: user-2 has read access to model-1 only
- permit:
principal: user-3
actions: [read]
resource: model::model-2
description: user-3 has read access to model-2 only
- forbid:
actions: [create, read, delete]
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
routing_table = ModelsRoutingTable(
impls_by_provider_id={"test_provider": mock_inference},
dist_registry=cached_disk_dist_registry,
policy=policy,
)
yield routing_table
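# Minimal sketch: the rule grammar used by this fixture (permit/forbid entries, optional
# principal and resource selectors, trailing catch-all forbid) can also be checked directly
# with is_action_allowed, reusing the names this test module already imports (yaml,
# TypeAdapter, AccessRule, User, ModelWithOwner, ModelType). The sketch assumes top-down,
# first-match evaluation and that a "model::<identifier>" resource selector matches on the
# model identifier -- the behaviour the user-2/user-3 assertions below rely on.
def _sketch_direct_policy_check():
    rules = TypeAdapter(list[AccessRule]).validate_python(
        yaml.safe_load(
            """
- permit:
    principal: user-2
    actions: [read]
    resource: model::model-1
- forbid:
    actions: [create, read, delete]
"""
        )
    )
    model_1 = ModelWithOwner(
        identifier="model-1",
        provider_id="test_provider",
        model_type=ModelType.llm,
    )
    # user-2 may read model-1; everything else falls through to the catch-all forbid.
    assert is_action_allowed(rules, "read", model_1, User("user-2", {}))
    assert not is_action_allowed(rules, "create", model_1, User("user-2", {}))
    assert not is_action_allowed(rules, "read", model_1, User("user-3", {}))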
@pytest.mark.asyncio
@patch("llama_stack.distribution.routing_tables.common.get_authenticated_user")
async def test_access_policy(mock_get_authenticated_user, test_setup_with_access_policy):
routing_table = test_setup_with_access_policy
mock_get_authenticated_user.return_value = User(
"user-1",
{
"roles": ["admin"],
"projects": ["foo", "bar"],
},
)
await routing_table.register_model("model-1", provider_id="test_provider")
await routing_table.register_model("model-2", provider_id="test_provider")
await routing_table.register_model("model-3", provider_id="test_provider")
model = await routing_table.get_model("model-1")
assert model.identifier == "model-1"
model = await routing_table.get_model("model-2")
assert model.identifier == "model-2"
model = await routing_table.get_model("model-3")
assert model.identifier == "model-3"
mock_get_authenticated_user.return_value = User(
"user-2",
{
"roles": ["user"],
"projects": ["foo"],
},
)
model = await routing_table.get_model("model-1")
assert model.identifier == "model-1"
with pytest.raises(ValueError):
await routing_table.get_model("model-2")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
with pytest.raises(AccessDeniedError):
await routing_table.register_model("model-4", provider_id="test_provider")
with pytest.raises(AccessDeniedError):
await routing_table.unregister_model("model-1")
mock_get_authenticated_user.return_value = User(
"user-3",
{
"roles": ["user"],
"projects": ["bar"],
},
)
model = await routing_table.get_model("model-2")
assert model.identifier == "model-2"
with pytest.raises(ValueError):
await routing_table.get_model("model-1")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
with pytest.raises(AccessDeniedError):
await routing_table.register_model("model-5", provider_id="test_provider")
with pytest.raises(AccessDeniedError):
await routing_table.unregister_model("model-2")
mock_get_authenticated_user.return_value = User(
"user-1",
{
"roles": ["admin"],
"projects": ["foo", "bar"],
},
)
await routing_table.unregister_model("model-3")
with pytest.raises(ValueError):
await routing_table.get_model("model-3")
def test_permit_when():
config = """
- permit:
principal: user-1
actions: [read]
when: user in owners namespaces
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
owner=User("testuser", {"namespaces": ["foo"]}),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_permit_unless():
config = """
- permit:
principal: user-1
actions: [read]
resource: model::*
unless:
- user not in owners namespaces
- user in owners teams
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
owner=User("testuser", {"namespaces": ["foo"]}),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert not is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_forbid_when():
config = """
- forbid:
principal: user-1
actions: [read]
when:
user in owners namespaces
- permit:
actions: [read]
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
owner=User("testuser", {"namespaces": ["foo"]}),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_forbid_unless():
config = """
- forbid:
principal: user-1
actions: [read]
unless:
user in owners namespaces
- permit:
actions: [read]
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
owner=User("testuser", {"namespaces": ["foo"]}),
)
assert is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-1", {"namespaces": ["bar"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"namespaces": ["foo"]}))
def test_user_has_attribute():
config = """
- permit:
actions: [read]
when: user with admin in roles
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_user_does_not_have_attribute():
config = """
- permit:
actions: [read]
unless: user with admin not in roles
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_is_owner():
config = """
- permit:
actions: [read]
when: user is owner
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
owner=User("user-2", {"namespaces": ["foo"]}),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_is_not_owner():
config = """
- permit:
actions: [read]
unless: user is not owner
"""
policy = TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
model = ModelWithOwner(
identifier="mymodel",
provider_id="myprovider",
model_type=ModelType.llm,
owner=User("user-2", {"namespaces": ["foo"]}),
)
assert not is_action_allowed(policy, "read", model, User("user-1", {"roles": ["basic"]}))
assert is_action_allowed(policy, "read", model, User("user-2", {"roles": ["admin"]}))
assert not is_action_allowed(policy, "read", model, User("user-3", {"namespaces": ["foo"]}))
assert not is_action_allowed(policy, "read", model, User("user-4", None))
def test_invalid_rule_permit_and_forbid_both_specified():
config = """
- permit:
actions: [read]
forbid:
actions: [create]
"""
with pytest.raises(ValidationError):
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
def test_invalid_rule_neither_permit_or_forbid_specified():
config = """
- when: user is owner
unless: user with admin in roles
"""
with pytest.raises(ValidationError):
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
def test_invalid_rule_when_and_unless_both_specified():
config = """
- permit:
actions: [read]
when: user is owner
unless: user with admin in roles
"""
with pytest.raises(ValidationError):
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
def test_invalid_condition():
config = """
- permit:
actions: [read]
when: random words that are not valid
"""
with pytest.raises(ValidationError):
TypeAdapter(list[AccessRule]).validate_python(yaml.safe_load(config))
@pytest.mark.parametrize(
"condition",
[
"user is owner",
"user is not owner",
"user with dev in teams",
"user with default not in namespaces",
"user in owners roles",
"user not in owners projects",
],
)
def test_condition_reprs(condition):
from llama_stack.distribution.access_control.conditions import parse_condition
assert condition == str(parse_condition(condition))
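# Minimal usage sketch of parse_condition (same import as in test_condition_reprs); the
# intended reading of each condition form is inferred from the ownership and attribute
# tests earlier in this file:
#   user is [not] owner              -> compare the caller against the resource owner
#   user with <value> [not] in <key> -> check the caller's own attributes
#   user [not] in owners <key>       -> compare the caller's attributes with the owner's
def _sketch_condition_roundtrip():
    from llama_stack.distribution.access_control.conditions import parse_condition

    cond = parse_condition("user with dev in teams")
    # Parsed conditions round-trip through str(), which is what test_condition_reprs asserts.
    assert str(cond) == "user with dev in teams"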

View file

@ -139,7 +139,7 @@ async def mock_post_success(*args, **kwargs):
{
"message": "Authentication successful",
"principal": "test-principal",
"access_attributes": {
"attributes": {
"roles": ["admin", "user"],
"teams": ["ml-team", "nlp-team"],
"projects": ["llama-3", "project-x"],
@ -233,7 +233,7 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
{
"message": "Authentication successful",
"principal": "test-principal",
"access_attributes": {
"attributes": {
"roles": ["admin", "user"],
"teams": ["ml-team", "nlp-team"],
"projects": ["llama-3", "project-x"],
@ -255,33 +255,6 @@ async def test_http_middleware_with_access_attributes(mock_http_middleware, mock
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
@pytest.mark.asyncio
async def test_http_middleware_no_attributes(mock_http_middleware, mock_scope):
"""Test middleware behavior with no access attributes"""
middleware, mock_app = mock_http_middleware
mock_receive = AsyncMock()
mock_send = AsyncMock()
with patch("httpx.AsyncClient") as mock_client:
mock_client_instance = AsyncMock()
mock_client.return_value.__aenter__.return_value = mock_client_instance
mock_client_instance.post.return_value = MockResponse(
200,
{
"message": "Authentication successful"
# No access_attributes
},
)
await middleware(mock_scope, mock_receive, mock_send)
assert "user_attributes" in mock_scope
attributes = mock_scope["user_attributes"]
assert "roles" in attributes
assert attributes["roles"] == ["test.jwt.token"]
# oauth2 token provider tests
@ -380,16 +353,16 @@ def test_get_attributes_from_claims():
"aud": "llama-stack",
}
attributes = get_attributes_from_claims(claims, {"sub": "roles", "groups": "teams"})
assert attributes.roles == ["my-user"]
assert attributes.teams == ["group1", "group2"]
assert attributes["roles"] == ["my-user"]
assert attributes["teams"] == ["group1", "group2"]
claims = {
"sub": "my-user",
"tenant": "my-tenant",
}
attributes = get_attributes_from_claims(claims, {"sub": "roles", "tenant": "namespaces"})
assert attributes.roles == ["my-user"]
assert attributes.namespaces == ["my-tenant"]
assert attributes["roles"] == ["my-user"]
assert attributes["namespaces"] == ["my-tenant"]
claims = {
"sub": "my-user",
@ -408,9 +381,9 @@ def test_get_attributes_from_claims():
"groups": "teams",
},
)
assert set(attributes.roles) == {"my-user", "my-username"}
assert set(attributes.teams) == {"my-team", "group1", "group2"}
assert attributes.namespaces == ["my-tenant"]
assert set(attributes["roles"]) == {"my-user", "my-username"}
assert set(attributes["teams"]) == {"my-team", "group1", "group2"}
assert attributes["namespaces"] == ["my-tenant"]
# TODO: add more tests for oauth2 token provider
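# Rough sketch of the claims-to-attributes mapping the updated assertions above rely on:
# the result is now a plain dict of lists rather than an AccessAttributes object. The helper
# name below is hypothetical and only mirrors the behaviour observed in the assertions, not
# the real get_attributes_from_claims implementation.
def _sketch_map_claims_to_attributes(claims: dict, mapping: dict[str, str]) -> dict[str, list[str]]:
    attributes: dict[str, list[str]] = {}
    for claim_name, attr_name in mapping.items():
        if claim_name not in claims:
            continue
        value = claims[claim_name]
        values = value if isinstance(value, list) else [value]
        bucket = attributes.setdefault(attr_name, [])
        for v in values:
            # Merge without duplicates, e.g. when "sub" and "username" both map to "roles".
            if v not in bucket:
                bucket.append(v)
    return attributes


# Mirrors the first pair of assertions above.
assert _sketch_map_claims_to_attributes(
    {"sub": "my-user", "groups": ["group1", "group2"], "aud": "llama-stack"},
    {"sub": "roles", "groups": "teams"},
) == {"roles": ["my-user"], "teams": ["group1", "group2"]}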

View file

@ -100,9 +100,10 @@ async def test_resolve_impls_basic():
add_protocol_methods(SampleImpl, Inference)
mock_module.get_provider_impl = AsyncMock(return_value=impl)
mock_module.get_provider_impl.__text_signature__ = "()"
sys.modules["test_module"] = mock_module
impls = await resolve_impls(run_config, provider_registry, dist_registry)
impls = await resolve_impls(run_config, provider_registry, dist_registry, policy={})
assert Api.inference in impls
assert isinstance(impls[Api.inference], InferenceRouter)

View file

@ -36,7 +36,7 @@ test_response_mcp_tool:
test_params:
case:
- case_id: "boiling_point_tool"
input: "What is the boiling point of polyjuice?"
input: "What is the boiling point of myawesomeliquid in Celsius?"
tools:
- type: mcp
server_label: "localmcp"
@ -94,3 +94,43 @@ test_response_multi_turn_image:
output: "llama"
- input: "What country do you find this animal primarily in? What continent?"
output: "peru"
test_response_multi_turn_tool_execution:
test_name: test_response_multi_turn_tool_execution
test_params:
case:
- case_id: "user_file_access_check"
input: "I need to check if user 'alice' can access the file 'document.txt'. First, get alice's user ID, then check if that user ID can access the file 'document.txt'. Do this as a series of steps, where each step is a separate message. Return only one tool call per step. Summarize the final result with a single 'yes' or 'no' response."
tools:
- type: mcp
server_label: "localmcp"
server_url: "<FILLED_BY_TEST_RUNNER>"
output: "yes"
- case_id: "experiment_results_lookup"
input: "I need to get the results for the 'boiling_point' experiment. First, get the experiment ID for 'boiling_point', then use that ID to get the experiment results. Tell me what you found."
tools:
- type: mcp
server_label: "localmcp"
server_url: "<FILLED_BY_TEST_RUNNER>"
output: "100°C"
test_response_multi_turn_tool_execution_streaming:
test_name: test_response_multi_turn_tool_execution_streaming
test_params:
case:
- case_id: "user_permissions_workflow"
input: "Help me with this security check: First, get the user ID for 'charlie', then get the permissions for that user ID, and finally check if that user can access 'secret_file.txt'. Stream your progress as you work through each step."
tools:
- type: mcp
server_label: "localmcp"
server_url: "<FILLED_BY_TEST_RUNNER>"
stream: true
output: "no"
- case_id: "experiment_analysis_streaming"
input: "I need a complete analysis: First, get the experiment ID for 'chemical_reaction', then get the results for that experiment, and tell me if the yield was above 80%. Please stream your analysis process."
tools:
- type: mcp
server_label: "localmcp"
server_url: "<FILLED_BY_TEST_RUNNER>"
stream: true
output: "85%"

View file

@ -12,7 +12,7 @@ import pytest
from llama_stack import LlamaStackAsLibraryClient
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from tests.common.mcp import make_mcp_server
from tests.common.mcp import dependency_tools, make_mcp_server
from tests.verifications.openai_api.fixtures.fixtures import (
case_id_generator,
get_base_test_name,
@ -280,6 +280,7 @@ def test_response_non_streaming_mcp_tool(request, openai_client, model, provider
tools=tools,
stream=False,
)
assert len(response.output) >= 3
list_tools = response.output[0]
assert list_tools.type == "mcp_list_tools"
@ -290,11 +291,12 @@ def test_response_non_streaming_mcp_tool(request, openai_client, model, provider
call = response.output[1]
assert call.type == "mcp_call"
assert call.name == "get_boiling_point"
assert json.loads(call.arguments) == {"liquid_name": "polyjuice", "celcius": True}
assert json.loads(call.arguments) == {"liquid_name": "myawesomeliquid", "celsius": True}
assert call.error is None
assert "-100" in call.output
message = response.output[2]
# sometimes the model will call the tool again, so we need to get the last message
message = response.output[-1]
text_content = message.content[0].text
assert "boiling point" in text_content.lower()
@ -393,3 +395,190 @@ def test_response_non_streaming_multi_turn_image(request, openai_client, model,
previous_response_id = response.id
output_text = response.output_text.lower()
assert turn["output"].lower() in output_text
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_multi_turn_tool_execution"]["test_params"]["case"],
ids=case_id_generator,
)
def test_response_non_streaming_multi_turn_tool_execution(
request, openai_client, model, provider, verification_config, case
):
"""Test multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
tools = case["tools"]
# Replace the placeholder URL with the actual server URL
for tool in tools:
if tool["type"] == "mcp" and tool["server_url"] == "<FILLED_BY_TEST_RUNNER>":
tool["server_url"] = mcp_server_info["server_url"]
response = openai_client.responses.create(
input=case["input"],
model=model,
tools=tools,
)
# Verify we have MCP tool calls in the output
mcp_list_tools = [output for output in response.output if output.type == "mcp_list_tools"]
mcp_calls = [output for output in response.output if output.type == "mcp_call"]
message_outputs = [output for output in response.output if output.type == "message"]
# Should have exactly 1 MCP list tools message (at the beginning)
assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
assert mcp_list_tools[0].server_label == "localmcp"
assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
expected_tool_names = {
"get_user_id",
"get_user_permissions",
"check_file_access",
"get_experiment_id",
"get_experiment_results",
}
assert {t["name"] for t in mcp_list_tools[0].tools} == expected_tool_names
assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
for mcp_call in mcp_calls:
assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
final_message = message_outputs[-1]
assert final_message.role == "assistant", f"Final message should be from assistant, got {final_message.role}"
assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
assert len(final_message.content) > 0, "Final message should have content"
expected_output = case["output"]
assert expected_output.lower() in response.output_text.lower(), (
f"Expected '{expected_output}' to appear in response: {response.output_text}"
)
@pytest.mark.parametrize(
"case",
responses_test_cases["test_response_multi_turn_tool_execution_streaming"]["test_params"]["case"],
ids=case_id_generator,
)
async def test_response_streaming_multi_turn_tool_execution(
request, openai_client, model, provider, verification_config, case
):
"""Test streaming multi-turn tool execution where multiple MCP tool calls are performed in sequence."""
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
with make_mcp_server(tools=dependency_tools()) as mcp_server_info:
tools = case["tools"]
# Replace the placeholder URL with the actual server URL
for tool in tools:
if tool["type"] == "mcp" and tool["server_url"] == "<FILLED_BY_TEST_RUNNER>":
tool["server_url"] = mcp_server_info["server_url"]
stream = openai_client.responses.create(
input=case["input"],
model=model,
tools=tools,
stream=True,
)
chunks = []
async for chunk in stream:
chunks.append(chunk)
# Should have at least response.created and response.completed
assert len(chunks) >= 2, f"Expected at least 2 chunks (created + completed), got {len(chunks)}"
# First chunk should be response.created
assert chunks[0].type == "response.created", f"First chunk should be response.created, got {chunks[0].type}"
# Last chunk should be response.completed
assert chunks[-1].type == "response.completed", (
f"Last chunk should be response.completed, got {chunks[-1].type}"
)
# Get the final response from the last chunk
final_chunk = chunks[-1]
if hasattr(final_chunk, "response"):
final_response = final_chunk.response
# Verify multi-turn MCP tool execution results
mcp_list_tools = [output for output in final_response.output if output.type == "mcp_list_tools"]
mcp_calls = [output for output in final_response.output if output.type == "mcp_call"]
message_outputs = [output for output in final_response.output if output.type == "message"]
# Should have exactly 1 MCP list tools message (at the beginning)
assert len(mcp_list_tools) == 1, f"Expected exactly 1 mcp_list_tools, got {len(mcp_list_tools)}"
assert mcp_list_tools[0].server_label == "localmcp"
assert len(mcp_list_tools[0].tools) == 5 # Updated for dependency tools
expected_tool_names = {
"get_user_id",
"get_user_permissions",
"check_file_access",
"get_experiment_id",
"get_experiment_results",
}
assert {t["name"] for t in mcp_list_tools[0].tools} == expected_tool_names
# Should have at least 1 MCP call (the model should call at least one tool)
assert len(mcp_calls) >= 1, f"Expected at least 1 mcp_call, got {len(mcp_calls)}"
# All MCP calls should be completed (verifies our tool execution works)
for mcp_call in mcp_calls:
assert mcp_call.error is None, f"MCP call should not have errors, got: {mcp_call.error}"
# Should have at least one final message response
assert len(message_outputs) >= 1, f"Expected at least 1 message output, got {len(message_outputs)}"
# Final message should be from assistant and completed
final_message = message_outputs[-1]
assert final_message.role == "assistant", (
f"Final message should be from assistant, got {final_message.role}"
)
assert final_message.status == "completed", f"Final message should be completed, got {final_message.status}"
assert len(final_message.content) > 0, "Final message should have content"
# Check that the expected output appears in the response
expected_output = case["output"]
assert expected_output.lower() in final_response.output_text.lower(), (
f"Expected '{expected_output}' to appear in response: {final_response.output_text}"
)
@pytest.mark.parametrize(
"text_format",
# Not testing json_object because most providers don't actually support it.
[
{"type": "text"},
{
"type": "json_schema",
"name": "capitals",
"description": "A schema for the capital of each country",
"schema": {"type": "object", "properties": {"capital": {"type": "string"}}},
"strict": True,
},
],
)
def test_response_text_format(request, openai_client, model, provider, verification_config, text_format):
if isinstance(openai_client, LlamaStackAsLibraryClient):
pytest.skip("Responses API text format is not yet supported in library client.")
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
stream = False
response = openai_client.responses.create(
model=model,
input="What is the capital of France?",
stream=stream,
text={"format": text_format},
)
# by_alias=True is needed because otherwise Pydantic renames our "schema" field
assert response.text.format.model_dump(exclude_none=True, by_alias=True) == text_format
assert "paris" in response.output_text.lower()
if text_format["type"] == "json_schema":
assert "paris" in json.loads(response.output_text)["capital"].lower()