fix: ABAC bypass in vector store operations (#4394)

Vector store operations were bypassing ABAC checks by calling providers
directly instead of going through the routing table. This allowed
unauthorized access to vector store data and operations.

Changes:
- Route all VectorIORouter methods through routing table instead of
  directly to providers
- Update routing table to enforce ABAC checks on all vector store
  operations (read, update, delete)
- Add test suite verifying ABAC enforcement for all vector store
  operations
- Ensure providers are never called when authorization fails

Fixes security issue where users could access vector stores they don't
have permission for.

Fixes: #4393

Signed-off-by: Derek Higgins <derekh@redhat.com>
This commit is contained in:
Derek Higgins 2025-12-16 18:49:16 +00:00 committed by GitHub
parent 401d3b8ce6
commit 5abb7df41a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 429 additions and 73 deletions

View file

@ -132,8 +132,7 @@ class VectorIORouter(VectorIO):
f"VectorIORouter.insert_chunks: {vector_store_id}, {len(chunks)} chunks, " f"VectorIORouter.insert_chunks: {vector_store_id}, {len(chunks)} chunks, "
f"ttl_seconds={ttl_seconds}, chunk_ids={doc_ids}{' and more...' if len(chunks) > 3 else ''}" f"ttl_seconds={ttl_seconds}, chunk_ids={doc_ids}{' and more...' if len(chunks) > 3 else ''}"
) )
provider = await self.routing_table.get_provider_impl(vector_store_id) return await self.routing_table.insert_chunks(vector_store_id, chunks, ttl_seconds)
return await provider.insert_chunks(vector_store_id, chunks, ttl_seconds)
async def query_chunks( async def query_chunks(
self, self,
@ -142,8 +141,7 @@ class VectorIORouter(VectorIO):
params: dict[str, Any] | None = None, params: dict[str, Any] | None = None,
) -> QueryChunksResponse: ) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_store_id}") logger.debug(f"VectorIORouter.query_chunks: {vector_store_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id) return await self.routing_table.query_chunks(vector_store_id, query, params)
return await provider.query_chunks(vector_store_id, query, params)
# OpenAI Vector Stores API endpoints # OpenAI Vector Stores API endpoints
async def openai_create_vector_store( async def openai_create_vector_store(
@ -248,9 +246,8 @@ class VectorIORouter(VectorIO):
all_stores = [] all_stores = []
for vector_store in vector_stores: for vector_store in vector_stores:
try: try:
provider = await self.routing_table.get_provider_impl(vector_store.identifier) vector_store_obj = await self.routing_table.openai_retrieve_vector_store(vector_store.identifier)
vector_store = await provider.openai_retrieve_vector_store(vector_store.identifier) all_stores.append(vector_store_obj)
all_stores.append(vector_store)
except Exception as e: except Exception as e:
logger.error(f"Error retrieving vector store {vector_store.identifier}: {e}") logger.error(f"Error retrieving vector store {vector_store.identifier}: {e}")
continue continue
@ -292,8 +289,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str, vector_store_id: str,
) -> VectorStoreObject: ) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}") logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id) return await self.routing_table.openai_retrieve_vector_store(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store( async def openai_update_vector_store(
self, self,
@ -310,8 +306,7 @@ class VectorIORouter(VectorIO):
if current_store and current_store.provider_id != metadata["provider_id"]: if current_store and current_store.provider_id != metadata["provider_id"]:
raise ValueError("provider_id cannot be changed after vector store creation") raise ValueError("provider_id cannot be changed after vector store creation")
provider = await self.routing_table.get_provider_impl(vector_store_id) return await self.routing_table.openai_update_vector_store(
return await provider.openai_update_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
name=name, name=name,
expires_after=expires_after, expires_after=expires_after,
@ -346,8 +341,7 @@ class VectorIORouter(VectorIO):
original_query = query original_query = query
search_query = await self._rewrite_query_for_search(original_query) search_query = await self._rewrite_query_for_search(original_query)
provider = await self.routing_table.get_provider_impl(vector_store_id) return await self.routing_table.openai_search_vector_store(
return await provider.openai_search_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
query=search_query, query=search_query,
filters=filters, filters=filters,
@ -367,8 +361,7 @@ class VectorIORouter(VectorIO):
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}") logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
if chunking_strategy is None or chunking_strategy.type == "auto": if chunking_strategy is None or chunking_strategy.type == "auto":
chunking_strategy = VectorStoreChunkingStrategyStatic(static=VectorStoreChunkingStrategyStaticConfig()) chunking_strategy = VectorStoreChunkingStrategyStatic(static=VectorStoreChunkingStrategyStaticConfig())
provider = await self.routing_table.get_provider_impl(vector_store_id) return await self.routing_table.openai_attach_file_to_vector_store(
return await provider.openai_attach_file_to_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_id=file_id, file_id=file_id,
attributes=attributes, attributes=attributes,
@ -385,8 +378,7 @@ class VectorIORouter(VectorIO):
filter: VectorStoreFileStatus | None = None, filter: VectorStoreFileStatus | None = None,
) -> list[VectorStoreFileObject]: ) -> list[VectorStoreFileObject]:
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}") logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id) return await self.routing_table.openai_list_files_in_vector_store(
return await provider.openai_list_files_in_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
limit=limit, limit=limit,
order=order, order=order,
@ -401,8 +393,7 @@ class VectorIORouter(VectorIO):
file_id: str, file_id: str,
) -> VectorStoreFileObject: ) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}") logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id) return await self.routing_table.openai_retrieve_vector_store_file(
return await provider.openai_retrieve_vector_store_file(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_id=file_id, file_id=file_id,
) )
@ -433,8 +424,7 @@ class VectorIORouter(VectorIO):
attributes: dict[str, Any], attributes: dict[str, Any],
) -> VectorStoreFileObject: ) -> VectorStoreFileObject:
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}") logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id) return await self.routing_table.openai_update_vector_store_file(
return await provider.openai_update_vector_store_file(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_id=file_id, file_id=file_id,
attributes=attributes, attributes=attributes,
@ -446,8 +436,7 @@ class VectorIORouter(VectorIO):
file_id: str, file_id: str,
) -> VectorStoreFileDeleteResponse: ) -> VectorStoreFileDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}") logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id) return await self.routing_table.openai_delete_vector_store_file(
return await provider.openai_delete_vector_store_file(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_id=file_id, file_id=file_id,
) )
@ -483,8 +472,10 @@ class VectorIORouter(VectorIO):
logger.debug( logger.debug(
f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(params.file_ids)} files" f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(params.file_ids)} files"
) )
provider = await self.routing_table.get_provider_impl(vector_store_id) return await self.routing_table.openai_create_vector_store_file_batch(
return await provider.openai_create_vector_store_file_batch(vector_store_id, params) vector_store_id=vector_store_id,
params=params,
)
async def openai_retrieve_vector_store_file_batch( async def openai_retrieve_vector_store_file_batch(
self, self,
@ -492,8 +483,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str, vector_store_id: str,
) -> VectorStoreFileBatchObject: ) -> VectorStoreFileBatchObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_batch: {batch_id}, {vector_store_id}") logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_batch: {batch_id}, {vector_store_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id) return await self.routing_table.openai_retrieve_vector_store_file_batch(
return await provider.openai_retrieve_vector_store_file_batch(
batch_id=batch_id, batch_id=batch_id,
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
) )
@ -509,8 +499,7 @@ class VectorIORouter(VectorIO):
order: str | None = "desc", order: str | None = "desc",
) -> VectorStoreFilesListInBatchResponse: ) -> VectorStoreFilesListInBatchResponse:
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store_file_batch: {batch_id}, {vector_store_id}") logger.debug(f"VectorIORouter.openai_list_files_in_vector_store_file_batch: {batch_id}, {vector_store_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id) return await self.routing_table.openai_list_files_in_vector_store_file_batch(
return await provider.openai_list_files_in_vector_store_file_batch(
batch_id=batch_id, batch_id=batch_id,
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
after=after, after=after,
@ -526,8 +515,7 @@ class VectorIORouter(VectorIO):
vector_store_id: str, vector_store_id: str,
) -> VectorStoreFileBatchObject: ) -> VectorStoreFileBatchObject:
logger.debug(f"VectorIORouter.openai_cancel_vector_store_file_batch: {batch_id}, {vector_store_id}") logger.debug(f"VectorIORouter.openai_cancel_vector_store_file_batch: {batch_id}, {vector_store_id}")
provider = await self.routing_table.get_provider_impl(vector_store_id) return await self.routing_table.openai_cancel_vector_store_file_batch(
return await provider.openai_cancel_vector_store_file_batch(
batch_id=batch_id, batch_id=batch_id,
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
) )

View file

@ -13,9 +13,13 @@ from llama_stack.log import get_logger
# Removed VectorStores import to avoid exposing public API # Removed VectorStores import to avoid exposing public API
from llama_stack_api import ( from llama_stack_api import (
Chunk,
InterleavedContent,
ModelNotFoundError, ModelNotFoundError,
ModelType, ModelType,
ModelTypeError, ModelTypeError,
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
QueryChunksResponse,
ResourceType, ResourceType,
SearchRankingOptions, SearchRankingOptions,
VectorStoreChunkingStrategy, VectorStoreChunkingStrategy,
@ -87,6 +91,26 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
await self.register_object(vector_store) await self.register_object(vector_store)
return vector_store return vector_store
async def insert_chunks(
self,
vector_store_id: str,
chunks: list[Chunk],
ttl_seconds: int | None = None,
) -> None:
await self.assert_action_allowed("update", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.insert_chunks(vector_store_id, chunks, ttl_seconds)
async def query_chunks(
self,
vector_store_id: str,
query: InterleavedContent,
params: dict[str, Any] | None = None,
) -> QueryChunksResponse:
await self.assert_action_allowed("read", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id)
return await provider.query_chunks(vector_store_id, query, params)
async def openai_retrieve_vector_store( async def openai_retrieve_vector_store(
self, self,
vector_store_id: str, vector_store_id: str,
@ -142,20 +166,6 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
search_mode: str | None = "vector", search_mode: str | None = "vector",
) -> VectorStoreSearchResponsePage: ) -> VectorStoreSearchResponsePage:
await self.assert_action_allowed("read", "vector_store", vector_store_id) await self.assert_action_allowed("read", "vector_store", vector_store_id)
# Delegate to VectorIORouter if available (which handles query rewriting)
if self.vector_io_router is not None:
return await self.vector_io_router.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=filters,
max_num_results=max_num_results,
ranking_options=ranking_options,
rewrite_query=rewrite_query,
search_mode=search_mode,
)
# Fallback to direct provider call if VectorIORouter not available
provider = await self.get_provider_impl(vector_store_id) provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store( return await provider.openai_search_vector_store(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
@ -261,17 +271,13 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
async def openai_create_vector_store_file_batch( async def openai_create_vector_store_file_batch(
self, self,
vector_store_id: str, vector_store_id: str,
file_ids: list[str], params: OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
attributes: dict[str, Any] | None = None,
chunking_strategy: Any | None = None,
): ):
await self.assert_action_allowed("update", "vector_store", vector_store_id) await self.assert_action_allowed("update", "vector_store", vector_store_id)
provider = await self.get_provider_impl(vector_store_id) provider = await self.get_provider_impl(vector_store_id)
return await provider.openai_create_vector_store_file_batch( return await provider.openai_create_vector_store_file_batch(
vector_store_id=vector_store_id, vector_store_id=vector_store_id,
file_ids=file_ids, params=params,
attributes=attributes,
chunking_strategy=chunking_strategy,
) )
async def openai_retrieve_vector_store_file_batch( async def openai_retrieve_vector_store_file_batch(

View file

@ -105,9 +105,7 @@ async def test_update_vector_store_same_provider_id_succeeds():
mock_existing_store.identifier = "vs_123" mock_existing_store.identifier = "vs_123"
mock_routing_table.get_object_by_identifier = AsyncMock(return_value=mock_existing_store) mock_routing_table.get_object_by_identifier = AsyncMock(return_value=mock_existing_store)
mock_routing_table.get_provider_impl = AsyncMock( mock_routing_table.openai_update_vector_store = AsyncMock(return_value=Mock(identifier="vs_123"))
return_value=Mock(openai_update_vector_store=AsyncMock(return_value=Mock(id="vs_123")))
)
router = VectorIORouter(mock_routing_table) router = VectorIORouter(mock_routing_table)
@ -118,10 +116,8 @@ async def test_update_vector_store_same_provider_id_succeeds():
metadata={"provider_id": "inline::faiss"}, # Same provider_id metadata={"provider_id": "inline::faiss"}, # Same provider_id
) )
# Verify the provider update method was called # Verify the routing table method was called
mock_routing_table.get_provider_impl.assert_called_once_with("vs_123") mock_routing_table.openai_update_vector_store.assert_called_once_with(
provider = await mock_routing_table.get_provider_impl("vs_123")
provider.openai_update_vector_store.assert_called_once_with(
vector_store_id="vs_123", name="updated_name", expires_after=None, metadata={"provider_id": "inline::faiss"} vector_store_id="vs_123", name="updated_name", expires_after=None, metadata={"provider_id": "inline::faiss"}
) )
@ -165,11 +161,9 @@ async def test_query_rewrite_functionality():
mock_routing_table = Mock() mock_routing_table = Mock()
# Mock provider that returns search results # Mock routing table method that returns search results
mock_provider = Mock()
mock_search_response = VectorStoreSearchResponsePage(search_query=["rewritten test query"], data=[], has_more=False) mock_search_response = VectorStoreSearchResponsePage(search_query=["rewritten test query"], data=[], has_more=False)
mock_provider.openai_search_vector_store = AsyncMock(return_value=mock_search_response) mock_routing_table.openai_search_vector_store = AsyncMock(return_value=mock_search_response)
mock_routing_table.get_provider_impl = AsyncMock(return_value=mock_provider)
# Mock inference API for query rewriting # Mock inference API for query rewriting
mock_inference_api = Mock() mock_inference_api = Mock()
@ -206,9 +200,9 @@ async def test_query_rewrite_functionality():
expected_prompt = DEFAULT_QUERY_REWRITE_PROMPT.format(query="test query") expected_prompt = DEFAULT_QUERY_REWRITE_PROMPT.format(query="test query")
assert prompt_text == expected_prompt assert prompt_text == expected_prompt
# Verify provider was called with rewritten query and rewrite_query=False # Verify routing table was called with rewritten query and rewrite_query=False
mock_provider.openai_search_vector_store.assert_called_once() mock_routing_table.openai_search_vector_store.assert_called_once()
call_kwargs = mock_provider.openai_search_vector_store.call_args.kwargs call_kwargs = mock_routing_table.openai_search_vector_store.call_args.kwargs
assert call_kwargs["query"] == "rewritten test query" assert call_kwargs["query"] == "rewritten test query"
assert call_kwargs["rewrite_query"] is False # Should be False since router handled it assert call_kwargs["rewrite_query"] is False # Should be False since router handled it
@ -242,10 +236,8 @@ async def test_query_rewrite_with_custom_prompt():
mock_routing_table = Mock() mock_routing_table = Mock()
mock_provider = Mock()
mock_search_response = VectorStoreSearchResponsePage(search_query=["custom rewrite"], data=[], has_more=False) mock_search_response = VectorStoreSearchResponsePage(search_query=["custom rewrite"], data=[], has_more=False)
mock_provider.openai_search_vector_store = AsyncMock(return_value=mock_search_response) mock_routing_table.openai_search_vector_store = AsyncMock(return_value=mock_search_response)
mock_routing_table.get_provider_impl = AsyncMock(return_value=mock_provider)
mock_inference_api = Mock() mock_inference_api = Mock()
mock_inference_api.openai_chat_completion = AsyncMock( mock_inference_api.openai_chat_completion = AsyncMock(
@ -283,10 +275,8 @@ async def test_search_without_rewrite():
mock_routing_table = Mock() mock_routing_table = Mock()
mock_provider = Mock()
mock_search_response = VectorStoreSearchResponsePage(search_query=["test query"], data=[], has_more=False) mock_search_response = VectorStoreSearchResponsePage(search_query=["test query"], data=[], has_more=False)
mock_provider.openai_search_vector_store = AsyncMock(return_value=mock_search_response) mock_routing_table.openai_search_vector_store = AsyncMock(return_value=mock_search_response)
mock_routing_table.get_provider_impl = AsyncMock(return_value=mock_provider)
mock_inference_api = Mock() mock_inference_api = Mock()
mock_inference_api.openai_chat_completion = AsyncMock() mock_inference_api.openai_chat_completion = AsyncMock()
@ -303,6 +293,6 @@ async def test_search_without_rewrite():
# Verify inference API was NOT called # Verify inference API was NOT called
assert not mock_inference_api.openai_chat_completion.called assert not mock_inference_api.openai_chat_completion.called
# Verify provider was called with original query # Verify routing table was called with original query
call_kwargs = mock_provider.openai_search_vector_store.call_args.kwargs call_kwargs = mock_routing_table.openai_search_vector_store.call_args.kwargs
assert call_kwargs["query"] == "test query" assert call_kwargs["query"] == "test query"

View file

@ -0,0 +1,372 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Tests for ABAC enforcement in vector store operations.
This test suite verifies that all vector store operations properly enforce
authorization checks through the router -> routing table -> ABAC flow.
"""
from unittest.mock import AsyncMock, Mock
import pytest
from llama_stack.core.routers.vector_io import VectorIORouter
from llama_stack.core.routing_tables.vector_stores import VectorStoresRoutingTable
from llama_stack_api import (
Chunk,
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
QueryChunksResponse,
VectorStoreChunkingStrategyStatic,
VectorStoreChunkingStrategyStaticConfig,
VectorStoreDeleteResponse,
VectorStoreFileBatchObject,
VectorStoreFileCounts,
VectorStoreFileDeleteResponse,
VectorStoreFileObject,
VectorStoreFilesListInBatchResponse,
VectorStoreListFilesResponse,
VectorStoreObject,
VectorStoreSearchResponsePage,
)
class MockDistRegistry:
    """Minimal stand-in for the distribution registry.

    The routing table under test only needs the attribute to exist; no real
    distribution is ever attached during these tests.
    """

    def __init__(self):
        # No distribution is configured in the test environment.
        self.dist = None
@pytest.fixture
def mock_provider():
    """Create a mock provider that returns valid responses for all operations."""

    def _store():
        # Canonical empty vector store returned by retrieve/update mocks.
        return VectorStoreObject(
            id="vs_123",
            created_at=1234567890,
            file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
        )

    def _file():
        # Canonical completed file object returned by the file-level mocks.
        return VectorStoreFileObject(
            id="file_123",
            chunking_strategy=VectorStoreChunkingStrategyStatic(static=VectorStoreChunkingStrategyStaticConfig()),
            created_at=1234567890,
            status="completed",
            vector_store_id="vs_123",
        )

    def _batch(status, *, completed=0, cancelled=0, in_progress=0):
        # File batch object in the given state; batches always carry 2 files.
        return VectorStoreFileBatchObject(
            id="batch_123",
            created_at=1234567890,
            vector_store_id="vs_123",
            status=status,
            file_counts=VectorStoreFileCounts(
                completed=completed, cancelled=cancelled, failed=0, in_progress=in_progress, total=2
            ),
        )

    provider = Mock()

    # Chunk-level operations.
    provider.insert_chunks = AsyncMock()
    provider.query_chunks = AsyncMock(return_value=QueryChunksResponse(chunks=[], scores=[]))

    # Vector-store-level operations.
    provider.openai_retrieve_vector_store = AsyncMock(return_value=_store())
    provider.openai_update_vector_store = AsyncMock(return_value=_store())
    provider.openai_delete_vector_store = AsyncMock(return_value=VectorStoreDeleteResponse(id="vs_123", deleted=True))
    provider.openai_search_vector_store = AsyncMock(
        return_value=VectorStoreSearchResponsePage(search_query=["test"], data=[], has_more=False)
    )

    # File-level operations.
    provider.openai_attach_file_to_vector_store = AsyncMock(return_value=_file())
    provider.openai_list_files_in_vector_store = AsyncMock(
        return_value=VectorStoreListFilesResponse(data=[], has_more=False)
    )
    provider.openai_retrieve_vector_store_file = AsyncMock(return_value=_file())
    provider.openai_update_vector_store_file = AsyncMock(return_value=_file())
    provider.openai_delete_vector_store_file = AsyncMock(
        return_value=VectorStoreFileDeleteResponse(id="file_123", deleted=True)
    )

    # File-batch operations.
    provider.openai_create_vector_store_file_batch = AsyncMock(return_value=_batch("in_progress", in_progress=2))
    provider.openai_retrieve_vector_store_file_batch = AsyncMock(return_value=_batch("completed", completed=2))
    provider.openai_list_files_in_vector_store_file_batch = AsyncMock(
        return_value=VectorStoreFilesListInBatchResponse(data=[], has_more=False)
    )
    provider.openai_cancel_vector_store_file_batch = AsyncMock(return_value=_batch("cancelled", cancelled=2))

    return provider
@pytest.fixture
def router_with_real_routing_table(mock_provider):
    """Create router with real routing table for integration testing."""
    from llama_stack.core.datatypes import VectorStoreWithOwner
    from llama_stack_api import ResourceType

    table = VectorStoresRoutingTable(
        impls_by_provider_id={"test-provider": mock_provider},
        dist_registry=MockDistRegistry(),
        policy=[],
    )

    # Always resolve to the mock provider instead of doing a real lookup.
    table.get_provider_impl = AsyncMock(return_value=mock_provider)

    # assert_action_allowed resolves the resource through
    # get_object_by_identifier, so stub it to hand back a registered-looking
    # vector store object.
    store = VectorStoreWithOwner(
        identifier="vs_123",
        provider_id="test-provider",
        provider_resource_id="vs_123",
        type=ResourceType.vector_store,
        embedding_model="test-model",
        embedding_dimension=768,
    )
    table.get_object_by_identifier = AsyncMock(return_value=store)

    # Wrap (not replace) the real ABAC check so tests can assert on calls
    # while the genuine authorization logic still executes.
    table.assert_action_allowed = AsyncMock(wraps=table.assert_action_allowed)

    return VectorIORouter(table), table, mock_provider
@pytest.mark.parametrize(
    "operation_name,expected_action,router_call,provider_method",
    [
        # Each entry: (test id, expected ABAC action, call through the router,
        # provider method that must be reached after authorization).
        (
            "insert_chunks",
            "update",
            lambda r: r.insert_chunks("vs_123", [Chunk(content="test", chunk_id="c1")]),
            "insert_chunks",
        ),
        (
            "query_chunks",
            "read",
            lambda r: r.query_chunks("vs_123", "test"),
            "query_chunks",
        ),
        (
            "openai_retrieve_vector_store",
            "read",
            lambda r: r.openai_retrieve_vector_store("vs_123"),
            "openai_retrieve_vector_store",
        ),
        (
            "openai_update_vector_store",
            "update",
            lambda r: r.openai_update_vector_store("vs_123", name="test"),
            "openai_update_vector_store",
        ),
        (
            "openai_delete_vector_store",
            "delete",
            lambda r: r.openai_delete_vector_store("vs_123"),
            "openai_delete_vector_store",
        ),
        (
            "openai_search_vector_store",
            "read",
            lambda r: r.openai_search_vector_store("vs_123", query="test"),
            "openai_search_vector_store",
        ),
        (
            "openai_attach_file_to_vector_store",
            "update",
            lambda r: r.openai_attach_file_to_vector_store("vs_123", "file_123"),
            "openai_attach_file_to_vector_store",
        ),
        (
            "openai_list_files_in_vector_store",
            "read",
            lambda r: r.openai_list_files_in_vector_store("vs_123"),
            "openai_list_files_in_vector_store",
        ),
        (
            "openai_retrieve_vector_store_file",
            "read",
            lambda r: r.openai_retrieve_vector_store_file("vs_123", "file_123"),
            "openai_retrieve_vector_store_file",
        ),
        (
            "openai_update_vector_store_file",
            "update",
            lambda r: r.openai_update_vector_store_file("vs_123", "file_123", {}),
            "openai_update_vector_store_file",
        ),
        (
            "openai_delete_vector_store_file",
            "delete",
            lambda r: r.openai_delete_vector_store_file("vs_123", "file_123"),
            "openai_delete_vector_store_file",
        ),
        (
            "openai_create_vector_store_file_batch",
            "update",
            lambda r: r.openai_create_vector_store_file_batch(
                "vs_123", OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["f1"])
            ),
            "openai_create_vector_store_file_batch",
        ),
        (
            "openai_retrieve_vector_store_file_batch",
            "read",
            lambda r: r.openai_retrieve_vector_store_file_batch("batch_123", "vs_123"),
            "openai_retrieve_vector_store_file_batch",
        ),
        (
            "openai_list_files_in_vector_store_file_batch",
            "read",
            lambda r: r.openai_list_files_in_vector_store_file_batch("batch_123", "vs_123"),
            "openai_list_files_in_vector_store_file_batch",
        ),
        (
            "openai_cancel_vector_store_file_batch",
            "update",
            lambda r: r.openai_cancel_vector_store_file_batch("batch_123", "vs_123"),
            "openai_cancel_vector_store_file_batch",
        ),
    ],
)
async def test_operation_enforces_correct_abac_permission(
    operation_name, expected_action, router_call, provider_method, router_with_real_routing_table
):
    """Test that each operation flows through router -> routing table -> ABAC check -> provider.

    This verifies:
    1. Router delegates to routing table (not directly to provider)
    2. Routing table enforces ABAC with correct action
    3. Provider is called only after authorization succeeds
    """
    router, routing_table, mock_provider = router_with_real_routing_table

    # Execute the operation through the router
    await router_call(router)

    # Verify ABAC check happened with correct action and resource
    routing_table.assert_action_allowed.assert_called_once_with(expected_action, "vector_store", "vs_123")

    # Verify provider was called after authorization
    provider_mock = getattr(mock_provider, provider_method)
    provider_mock.assert_called_once()
async def test_operations_fail_before_provider_when_unauthorized(router_with_real_routing_table):
    """Test that all operations fail at ABAC check before calling provider when unauthorized."""
    router, routing_table, mock_provider = router_with_real_routing_table

    # Force the authorization check to reject every request.
    routing_table.assert_action_allowed.side_effect = PermissionError("Access denied")

    # One entry per vector-store operation exposed by the router.
    operations = [
        ("insert_chunks", lambda: router.insert_chunks("vs_123", [Chunk(content="test", chunk_id="c1")])),
        ("query_chunks", lambda: router.query_chunks("vs_123", "test")),
        ("openai_retrieve_vector_store", lambda: router.openai_retrieve_vector_store("vs_123")),
        ("openai_update_vector_store", lambda: router.openai_update_vector_store("vs_123", name="test")),
        ("openai_delete_vector_store", lambda: router.openai_delete_vector_store("vs_123")),
        ("openai_search_vector_store", lambda: router.openai_search_vector_store("vs_123", query="test")),
        ("openai_attach_file_to_vector_store", lambda: router.openai_attach_file_to_vector_store("vs_123", "file_123")),
        ("openai_list_files_in_vector_store", lambda: router.openai_list_files_in_vector_store("vs_123")),
        ("openai_retrieve_vector_store_file", lambda: router.openai_retrieve_vector_store_file("vs_123", "file_123")),
        ("openai_update_vector_store_file", lambda: router.openai_update_vector_store_file("vs_123", "file_123", {})),
        ("openai_delete_vector_store_file", lambda: router.openai_delete_vector_store_file("vs_123", "file_123")),
        (
            "openai_create_vector_store_file_batch",
            lambda: router.openai_create_vector_store_file_batch(
                "vs_123", OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["f1"])
            ),
        ),
        (
            "openai_retrieve_vector_store_file_batch",
            lambda: router.openai_retrieve_vector_store_file_batch("batch_123", "vs_123"),
        ),
        (
            "openai_list_files_in_vector_store_file_batch",
            lambda: router.openai_list_files_in_vector_store_file_batch("batch_123", "vs_123"),
        ),
        (
            "openai_cancel_vector_store_file_batch",
            lambda: router.openai_cancel_vector_store_file_batch("batch_123", "vs_123"),
        ),
    ]

    for op_name, op_call in operations:
        # Re-arm the rejecting ABAC check for each operation.
        routing_table.assert_action_allowed.reset_mock()
        routing_table.assert_action_allowed.side_effect = PermissionError("Access denied")

        # The operation must surface the authorization failure unchanged.
        with pytest.raises(PermissionError, match="Access denied"):
            await op_call()

        # And the ABAC check must actually have been consulted.
        assert routing_table.assert_action_allowed.called, f"{op_name} should check permissions"

    # No provider method may have been reached for any of the 15 operations.
    for method_name in (
        "insert_chunks",
        "query_chunks",
        "openai_retrieve_vector_store",
        "openai_update_vector_store",
        "openai_delete_vector_store",
        "openai_search_vector_store",
        "openai_attach_file_to_vector_store",
        "openai_list_files_in_vector_store",
        "openai_retrieve_vector_store_file",
        "openai_update_vector_store_file",
        "openai_delete_vector_store_file",
        "openai_create_vector_store_file_batch",
        "openai_retrieve_vector_store_file_batch",
        "openai_list_files_in_vector_store_file_batch",
        "openai_cancel_vector_store_file_batch",
    ):
        getattr(mock_provider, method_name).assert_not_called()