Merge branch 'main' into openai-mixin

This commit is contained in:
Matthew Farrellee 2025-07-22 10:27:10 -04:00
commit 0ff9ae01a0
21 changed files with 900 additions and 214 deletions

View file

@ -50,6 +50,7 @@ def setup_telemetry_data(llama_stack_client, text_model_id):
yield
@pytest.mark.skip(reason="Skipping telemetry tests for now")
def test_query_traces_basic(llama_stack_client):
"""Test basic trace querying functionality with proper data validation."""
all_traces = llama_stack_client.telemetry.query_traces(limit=5)
@ -105,6 +106,7 @@ def test_query_traces_basic(llama_stack_client):
assert hasattr(trace, "root_span_id") and trace.root_span_id, "Each trace should have non-empty root_span_id"
@pytest.mark.skip(reason="Skipping telemetry tests for now")
def test_query_spans_basic(llama_stack_client):
"""Test basic span querying functionality with proper validation."""
spans = llama_stack_client.telemetry.query_spans(attribute_filters=[], attributes_to_return=[])
@ -153,6 +155,7 @@ def test_query_spans_basic(llama_stack_client):
assert hasattr(span, attr) and getattr(span, attr), f"All spans should have non-empty {attr}"
@pytest.mark.skip(reason="Skipping telemetry tests for now")
def test_telemetry_pagination(llama_stack_client):
"""Test pagination in telemetry queries."""
# Get total count of traces

View file

@ -11,17 +11,15 @@ from unittest.mock import AsyncMock
from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.models import Model
from llama_stack.apis.shields.shields import Shield
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable
from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable
from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable
from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable
class Impl:
@ -54,17 +52,6 @@ class SafetyImpl(Impl):
return shield
class VectorDBImpl(Impl):
def __init__(self):
super().__init__(Api.vector_io)
async def register_vector_db(self, vector_db: VectorDB):
return vector_db
async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id
class DatasetsImpl(Impl):
def __init__(self):
super().__init__(Api.datasetio)
@ -173,36 +160,6 @@ async def test_shields_routing_table(cached_disk_dist_registry):
assert "test-shield-2" in shield_ids
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
# Register multiple vector databases and verify listing
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2
vector_db_ids = {v.identifier for v in vector_dbs.data}
assert "test-vectordb" in vector_db_ids
assert "test-vectordb-2" in vector_db_ids
await table.unregister_vector_db(vector_db_id="test-vectordb")
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0
async def test_datasets_routing_table(cached_disk_dist_registry):
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
await table.initialize()

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,274 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Unit tests for the routing tables vector_dbs
import time
from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.datatypes import Api
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
from llama_stack.apis.vector_io.vector_io import (
VectorStoreContent,
VectorStoreDeleteResponse,
VectorStoreFileContentsResponse,
VectorStoreFileCounts,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
from llama_stack.distribution.access_control.datatypes import AccessRule, Scope
from llama_stack.distribution.datatypes import User
from llama_stack.distribution.request_headers import request_provider_data_context
from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable
from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceImpl, ModelsRoutingTable
class VectorDBImpl(Impl):
def __init__(self):
super().__init__(Api.vector_io)
async def register_vector_db(self, vector_db: VectorDB):
return vector_db
async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id
async def openai_retrieve_vector_store(self, vector_store_id):
return VectorStoreObject(
id=vector_store_id,
name="Test Store",
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def openai_update_vector_store(self, vector_store_id, **kwargs):
return VectorStoreObject(
id=vector_store_id,
name="Updated Store",
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def openai_delete_vector_store(self, vector_store_id):
return VectorStoreDeleteResponse(id=vector_store_id, object="vector_store.deleted", deleted=True)
async def openai_search_vector_store(self, vector_store_id, query, **kwargs):
return VectorStoreSearchResponsePage(
object="vector_store.search_results.page", search_query="query", data=[], has_more=False, next_page=None
)
async def openai_attach_file_to_vector_store(self, vector_store_id, file_id, **kwargs):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_list_files_in_vector_store(self, vector_store_id, **kwargs):
return [
VectorStoreFileObject(
id="1",
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
]
async def openai_retrieve_vector_store_file(self, vector_store_id, file_id):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_retrieve_vector_store_file_contents(self, vector_store_id, file_id):
return VectorStoreFileContentsResponse(
file_id=file_id,
filename="Sample File name",
attributes={"key": "value"},
content=[VectorStoreContent(type="text", text="Sample content")],
)
async def openai_update_vector_store_file(self, vector_store_id, file_id, **kwargs):
return VectorStoreFileObject(
id=file_id,
status="completed",
chunking_strategy={"type": "auto"},
created_at=int(time.time()),
vector_store_id=vector_store_id,
)
async def openai_delete_vector_store_file(self, vector_store_id, file_id):
return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
# Register multiple vector databases and verify listing
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2
vector_db_ids = {v.identifier for v in vector_dbs.data}
assert "test-vectordb" in vector_db_ids
assert "test-vectordb-2" in vector_db_ids
await table.unregister_vector_db(vector_db_id="test-vectordb")
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
impl = VectorDBImpl()
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=[])
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
authorized_table = "vs1"
authorized_team = "team1"
unauthorized_team = "team2"
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
with request_provider_data_context({}, authorized_user):
_ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
# Authorized reader
with request_provider_data_context({}, authorized_user):
res = await table.openai_retrieve_vector_store(authorized_table)
assert res == "OK"
# Authorized updater
impl.openai_update_vector_store_file = AsyncMock(return_value="UPDATED")
with request_provider_data_context({}, authorized_user):
res = await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
assert res == "UPDATED"
# Unauthorized reader
unauthorized_user = User(principal="eve", attributes={"roles": [unauthorized_team]})
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_retrieve_vector_store(authorized_table)
# Unauthorized updater
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
# Authorized deleter
impl.openai_delete_vector_store_file = AsyncMock(return_value="DELETED")
with request_provider_data_context({}, authorized_user):
res = await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
assert res == "DELETED"
# Unauthorized deleter
with request_provider_data_context({}, unauthorized_user):
with pytest.raises(ValueError):
await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_registry):
impl = VectorDBImpl()
policy = [
AccessRule(permit=Scope(actions=["create", "read", "update", "delete"]), when="user with admin in roles"),
AccessRule(permit=Scope(actions=["read"]), when="user with reader in roles"),
]
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=policy)
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
vector_db_id = "vs1"
file_id = "file-1"
admin_user = User(principal="admin", attributes={"roles": ["admin"]})
read_only_user = User(principal="reader", attributes={"roles": ["reader"]})
no_access_user = User(principal="outsider", attributes={"roles": ["no_access"]})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
with request_provider_data_context({}, admin_user):
await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
read_methods = [
(table.openai_retrieve_vector_store, (vector_db_id,), {}),
(table.openai_search_vector_store, (vector_db_id, "query"), {}),
(table.openai_list_files_in_vector_store, (vector_db_id,), {}),
(table.openai_retrieve_vector_store_file, (vector_db_id, file_id), {}),
(table.openai_retrieve_vector_store_file_contents, (vector_db_id, file_id), {}),
]
update_methods = [
(table.openai_update_vector_store, (vector_db_id,), {"name": "Updated DB"}),
(table.openai_attach_file_to_vector_store, (vector_db_id, file_id), {}),
(table.openai_update_vector_store_file, (vector_db_id, file_id), {"attributes": {"key": "value"}}),
]
delete_methods = [
(table.openai_delete_vector_store_file, (vector_db_id, file_id), {}),
(table.openai_delete_vector_store, (vector_db_id,), {}),
]
for user in [admin_user, read_only_user]:
with request_provider_data_context({}, user):
for method, args, kwargs in read_methods:
result = await method(*args, **kwargs)
assert result is not None, f"Read operation failed with user {user.principal}"
with request_provider_data_context({}, no_access_user):
for method, args, kwargs in read_methods:
with pytest.raises(ValueError):
await method(*args, **kwargs)
with request_provider_data_context({}, admin_user):
for method, args, kwargs in update_methods:
result = await method(*args, **kwargs)
assert result is not None, "Update operation failed with admin user"
with request_provider_data_context({}, admin_user):
for method, args, kwargs in delete_methods:
result = await method(*args, **kwargs)
assert result is not None, "Delete operation failed with admin user"
for user in [read_only_user, no_access_user]:
with request_provider_data_context({}, user):
for method, args, kwargs in delete_methods:
with pytest.raises(ValueError):
await method(*args, **kwargs)

View file

@ -0,0 +1,176 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from llama_stack.apis.agents import Document
from llama_stack.apis.common.content_types import URL, TextContentItem
from llama_stack.providers.inline.agents.meta_reference.agent_instance import get_raw_document_text
async def test_get_raw_document_text_supports_text_mime_types():
"""Test that the function accepts text/* mime types."""
document = Document(content="Sample text content", mime_type="text/plain")
result = await get_raw_document_text(document)
assert result == "Sample text content"
async def test_get_raw_document_text_supports_yaml_mime_type():
"""Test that the function accepts application/yaml mime type."""
yaml_content = """
name: test
version: 1.0
items:
- item1
- item2
"""
document = Document(content=yaml_content, mime_type="application/yaml")
result = await get_raw_document_text(document)
assert result == yaml_content
async def test_get_raw_document_text_supports_deprecated_text_yaml_with_warning():
"""Test that the function accepts text/yaml but emits a deprecation warning."""
yaml_content = """
name: test
version: 1.0
items:
- item1
- item2
"""
document = Document(content=yaml_content, mime_type="text/yaml")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = await get_raw_document_text(document)
# Check that result is correct
assert result == yaml_content
# Check that exactly one warning was issued
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "text/yaml" in str(w[0].message)
assert "application/yaml" in str(w[0].message)
assert "deprecated" in str(w[0].message).lower()
async def test_get_raw_document_text_deprecated_text_yaml_with_url():
"""Test that text/yaml works with URL content and emits warning."""
yaml_content = "name: test\nversion: 1.0"
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
mock_load.return_value = yaml_content
document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="text/yaml")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = await get_raw_document_text(document)
# Check that result is correct
assert result == yaml_content
mock_load.assert_called_once_with("https://example.com/config.yaml")
# Check that deprecation warning was issued
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "text/yaml" in str(w[0].message)
async def test_get_raw_document_text_deprecated_text_yaml_with_text_content_item():
"""Test that text/yaml works with TextContentItem and emits warning."""
yaml_content = "key: value\nlist:\n - item1\n - item2"
document = Document(content=TextContentItem(text=yaml_content), mime_type="text/yaml")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
result = await get_raw_document_text(document)
# Check that result is correct
assert result == yaml_content
# Check that deprecation warning was issued
assert len(w) == 1
assert issubclass(w[0].category, DeprecationWarning)
assert "text/yaml" in str(w[0].message)
async def test_get_raw_document_text_rejects_unsupported_mime_types():
"""Test that the function rejects unsupported mime types."""
document = Document(
content="Some content",
mime_type="application/json", # Not supported
)
with pytest.raises(ValueError, match="Unexpected document mime type: application/json"):
await get_raw_document_text(document)
async def test_get_raw_document_text_with_url_content():
"""Test that the function handles URL content correctly."""
mock_response = AsyncMock()
mock_response.text = "Content from URL"
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
mock_load.return_value = "Content from URL"
document = Document(content=URL(uri="https://example.com/test.txt"), mime_type="text/plain")
result = await get_raw_document_text(document)
assert result == "Content from URL"
mock_load.assert_called_once_with("https://example.com/test.txt")
async def test_get_raw_document_text_with_yaml_url():
"""Test that the function handles YAML URLs correctly."""
yaml_content = "name: test\nversion: 1.0"
with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load:
mock_load.return_value = yaml_content
document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="application/yaml")
result = await get_raw_document_text(document)
assert result == yaml_content
mock_load.assert_called_once_with("https://example.com/config.yaml")
async def test_get_raw_document_text_with_text_content_item():
"""Test that the function handles TextContentItem correctly."""
document = Document(content=TextContentItem(text="Text content item"), mime_type="text/plain")
result = await get_raw_document_text(document)
assert result == "Text content item"
async def test_get_raw_document_text_with_yaml_text_content_item():
"""Test that the function handles YAML TextContentItem correctly."""
yaml_content = "key: value\nlist:\n - item1\n - item2"
document = Document(content=TextContentItem(text=yaml_content), mime_type="application/yaml")
result = await get_raw_document_text(document)
assert result == yaml_content
async def test_get_raw_document_text_rejects_unexpected_content_type():
"""Test that the function rejects unexpected document content types."""
# Create a mock document that bypasses Pydantic validation
mock_document = MagicMock(spec=Document)
mock_document.mime_type = "text/plain"
mock_document.content = 123 # Unexpected content type (not str, URL, or TextContentItem)
with pytest.raises(ValueError, match="Unexpected document content type: <class 'int'>"):
await get_raw_document_text(mock_document)