Resolve merge conflict in server.py

This commit is contained in:
ehhuang 2025-07-08 00:27:58 -07:00
commit 9ece598705
173 changed files with 2655 additions and 10307 deletions

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest.fixture(scope="function")
@pytest_asyncio.fixture(scope="function")
async def sqlite_kvstore(tmp_path):
db_path = tmp_path / "test_kv.db"
kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix())
@ -20,14 +20,14 @@ async def sqlite_kvstore(tmp_path):
yield kvstore
@pytest.fixture(scope="function")
@pytest_asyncio.fixture(scope="function")
async def disk_dist_registry(sqlite_kvstore):
registry = DiskDistributionRegistry(sqlite_kvstore)
await registry.initialize()
yield registry
@pytest.fixture(scope="function")
@pytest_asyncio.fixture(scope="function")
async def cached_disk_dist_registry(sqlite_kvstore):
registry = CachedDiskDistributionRegistry(sqlite_kvstore)
await registry.initialize()

View file

@ -9,6 +9,7 @@ from datetime import datetime
from unittest.mock import patch
import pytest
import pytest_asyncio
from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason
@ -16,7 +17,7 @@ from llama_stack.distribution.datatypes import User
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
@pytest.fixture
@pytest_asyncio.fixture
async def test_setup(sqlite_kvstore):
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
yield agent_persistence

View file

@ -10,7 +10,7 @@ import pytest
from llama_stack.apis.common.content_types import URL, TextContentItem
from llama_stack.apis.tools import RAGDocument
from llama_stack.providers.utils.memory.vector_store import content_from_doc
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, content_from_doc
@pytest.mark.asyncio
@ -143,3 +143,45 @@ async def test_content_from_doc_with_interleaved_content():
assert result == "First item\nSecond item"
mock_interleaved.assert_called_once_with(interleaved_content)
def test_content_from_data_and_mime_type_success_utf8():
"""Test successful decoding with UTF-8 encoding."""
data = "Hello World! 🌍".encode()
mime_type = "text/plain"
with patch("chardet.detect") as mock_detect:
mock_detect.return_value = {"encoding": "utf-8"}
result = content_from_data_and_mime_type(data, mime_type)
mock_detect.assert_called_once_with(data)
assert result == "Hello World! 🌍"
def test_content_from_data_and_mime_type_error_win1252():
"""Test fallback to UTF-8 when Windows-1252 encoding detection fails."""
data = "Hello World! 🌍".encode()
mime_type = "text/plain"
with patch("chardet.detect") as mock_detect:
mock_detect.return_value = {"encoding": "Windows-1252"}
result = content_from_data_and_mime_type(data, mime_type)
assert result == "Hello World! 🌍"
mock_detect.assert_called_once_with(data)
def test_content_from_data_and_mime_type_both_encodings_fail():
"""Test that exceptions are raised when both primary and UTF-8 encodings fail."""
# Create invalid byte sequence that fails with both encodings
data = b"\xff\xfe\x00\x8f" # Invalid UTF-8 sequence
mime_type = "text/plain"
with patch("chardet.detect") as mock_detect:
mock_detect.return_value = {"encoding": "windows-1252"}
# Should raise an exception instead of returning empty string
with pytest.raises(UnicodeDecodeError):
content_from_data_and_mime_type(data, mime_type)

View file

@ -32,6 +32,14 @@ def test_generate_chunk_id():
]
def test_generate_chunk_id_with_window():
chunk = Chunk(content="test", metadata={"document_id": "doc-1"})
chunk_id1 = generate_chunk_id("doc-1", chunk, chunk_window="0-1")
chunk_id2 = generate_chunk_id("doc-1", chunk, chunk_window="1-2")
assert chunk_id1 == "149018fe-d0eb-0f8d-5f7f-726bdd2aeedb"
assert chunk_id2 == "4562c1ee-9971-1f3b-51a6-7d05e5211154"
def test_chunk_id():
# Test with existing chunk ID
chunk_with_id = Chunk(content="test", metadata={"document_id": "existing-id"})

View file

@ -148,7 +148,7 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime
assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"
@pytest.fixture(scope="session")
@pytest_asyncio.fixture(scope="session")
async def sqlite_vec_adapter(sqlite_connection):
config = type("Config", (object,), {"db_path": ":memory:"}) # Mock config with in-memory database
adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None)

View file

@ -7,6 +7,7 @@
from unittest.mock import MagicMock, Mock, patch
import pytest
import pytest_asyncio
import yaml
from pydantic import TypeAdapter, ValidationError
@ -26,7 +27,7 @@ def _return_model(model):
return model
@pytest.fixture
@pytest_asyncio.fixture
async def test_setup(cached_disk_dist_registry):
mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock()
@ -245,7 +246,7 @@ async def test_automatic_access_attributes(mock_get_authenticated_user, test_set
assert model.identifier == "auto-access-model"
@pytest.fixture
@pytest_asyncio.fixture
async def test_setup_with_access_policy(cached_disk_dist_registry):
mock_inference = Mock()
mock_inference.__provider_spec__ = MagicMock()

View file

@ -0,0 +1,187 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import Mock
from fastapi import HTTPException
from openai import BadRequestError
from pydantic import ValidationError
from llama_stack.distribution.access_control.access_control import AccessDeniedError
from llama_stack.distribution.datatypes import AuthenticationRequiredError
from llama_stack.distribution.server.server import translate_exception
class TestTranslateException:
"""Test cases for the translate_exception function."""
def test_translate_access_denied_error(self):
"""Test that AccessDeniedError is translated to 403 HTTP status."""
exc = AccessDeniedError()
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 403
assert result.detail == "Permission denied: Insufficient permissions"
def test_translate_access_denied_error_with_context(self):
"""Test that AccessDeniedError with context includes detailed information."""
from llama_stack.distribution.datatypes import User
# Create mock user and resource
user = User("test-user", {"roles": ["user"], "teams": ["dev"]})
# Create a simple mock object that implements the ProtectedResource protocol
class MockResource:
def __init__(self, type: str, identifier: str, owner=None):
self.type = type
self.identifier = identifier
self.owner = owner
resource = MockResource("vector_db", "test-db")
exc = AccessDeniedError("create", resource, user)
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 403
assert "test-user" in result.detail
assert "vector_db::test-db" in result.detail
assert "create" in result.detail
assert "roles=['user']" in result.detail
assert "teams=['dev']" in result.detail
def test_translate_permission_error(self):
"""Test that PermissionError is translated to 403 HTTP status."""
exc = PermissionError("Permission denied")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 403
assert result.detail == "Permission denied: Permission denied"
def test_translate_value_error(self):
"""Test that ValueError is translated to 400 HTTP status."""
exc = ValueError("Invalid input")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 400
assert result.detail == "Invalid value: Invalid input"
def test_translate_bad_request_error(self):
"""Test that BadRequestError is translated to 400 HTTP status."""
# Create a mock response for BadRequestError
mock_response = Mock()
mock_response.status_code = 400
mock_response.headers = {}
exc = BadRequestError("Bad request", response=mock_response, body="Bad request")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 400
assert result.detail == "Bad request"
def test_translate_authentication_required_error(self):
"""Test that AuthenticationRequiredError is translated to 401 HTTP status."""
exc = AuthenticationRequiredError("Authentication required")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 401
assert result.detail == "Authentication required: Authentication required"
def test_translate_timeout_error(self):
"""Test that TimeoutError is translated to 504 HTTP status."""
exc = TimeoutError("Operation timed out")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 504
assert result.detail == "Operation timed out: Operation timed out"
def test_translate_asyncio_timeout_error(self):
"""Test that asyncio.TimeoutError is translated to 504 HTTP status."""
exc = TimeoutError()
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 504
assert result.detail == "Operation timed out: "
def test_translate_not_implemented_error(self):
"""Test that NotImplementedError is translated to 501 HTTP status."""
exc = NotImplementedError("Not implemented")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 501
assert result.detail == "Not implemented: Not implemented"
def test_translate_validation_error(self):
"""Test that ValidationError is translated to 400 HTTP status with proper format."""
# Create a mock validation error using proper Pydantic error format
exc = ValidationError.from_exception_data(
"TestModel",
[
{
"loc": ("field", "nested"),
"msg": "field required",
"type": "missing",
}
],
)
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 400
assert "errors" in result.detail
assert len(result.detail["errors"]) == 1
assert result.detail["errors"][0]["loc"] == ["field", "nested"]
assert result.detail["errors"][0]["msg"] == "Field required"
assert result.detail["errors"][0]["type"] == "missing"
def test_translate_generic_exception(self):
"""Test that generic exceptions are translated to 500 HTTP status."""
exc = Exception("Unexpected error")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 500
assert result.detail == "Internal server error: An unexpected error occurred."
def test_translate_runtime_error(self):
"""Test that RuntimeError is translated to 500 HTTP status."""
exc = RuntimeError("Runtime error")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 500
assert result.detail == "Internal server error: An unexpected error occurred."
def test_multiple_access_denied_scenarios(self):
"""Test various scenarios that should result in 403 status codes."""
# Test AccessDeniedError (uses enhanced message)
exc1 = AccessDeniedError()
result1 = translate_exception(exc1)
assert isinstance(result1, HTTPException)
assert result1.status_code == 403
assert result1.detail == "Permission denied: Insufficient permissions"
# Test PermissionError (uses generic message)
exc2 = PermissionError("No permission")
result2 = translate_exception(exc2)
assert isinstance(result2, HTTPException)
assert result2.status_code == 403
assert result2.detail == "Permission denied: No permission"
exc3 = PermissionError("Access denied")
result3 = translate_exception(exc3)
assert isinstance(result3, HTTPException)
assert result3.status_code == 403
assert result3.detail == "Permission denied: Access denied"

View file

@ -104,19 +104,17 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
# Test scenarios with different access control patterns
test_scenarios = [
# Scenario 1: Public record (no access control)
# Scenario 1: Public record (no access control - represents None user insert)
{"id": "1", "name": "public", "access_attributes": None},
# Scenario 2: Empty access control (should be treated as public)
{"id": "2", "name": "empty", "access_attributes": {}},
# Scenario 3: Record with roles requirement
{"id": "3", "name": "admin-only", "access_attributes": {"roles": ["admin"]}},
# Scenario 4: Record with multiple attribute categories
{"id": "4", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
# Scenario 5: Record with teams only (missing roles category)
{"id": "5", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}},
# Scenario 6: Record with roles and projects
# Scenario 2: Record with roles requirement
{"id": "2", "name": "admin-only", "access_attributes": {"roles": ["admin"]}},
# Scenario 3: Record with multiple attribute categories
{"id": "3", "name": "admin-ml-team", "access_attributes": {"roles": ["admin"], "teams": ["ml-team"]}},
# Scenario 4: Record with teams only (missing roles category)
{"id": "4", "name": "ml-team-only", "access_attributes": {"teams": ["ml-team"]}},
# Scenario 5: Record with roles and projects
{
"id": "6",
"id": "5",
"name": "admin-project-x",
"access_attributes": {"roles": ["admin"], "projects": ["project-x"]},
},