mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 11:42:36 +00:00
fix: ABAC bypass in vector store operations (#4394)
Vector store operations were bypassing ABAC checks by calling providers directly instead of going through the routing table. This allowed unauthorized access to vector store data and operations. Changes: o Route all VectorIORouter methods through routing table instead of directly to providers o Update routing table to enforce ABAC checks on all vector store operations (read, update, delete) o Add test suite verifying ABAC enforcement for all vector store operations o Ensure providers are never called when authorization fails Fixes security issue where users could access vector stores they don't have permission for. Fixes: #4393 Signed-off-by: Derek Higgins <derekh@redhat.com>
This commit is contained in:
parent
401d3b8ce6
commit
5abb7df41a
4 changed files with 429 additions and 73 deletions
|
|
@ -132,8 +132,7 @@ class VectorIORouter(VectorIO):
|
||||||
f"VectorIORouter.insert_chunks: {vector_store_id}, {len(chunks)} chunks, "
|
f"VectorIORouter.insert_chunks: {vector_store_id}, {len(chunks)} chunks, "
|
||||||
f"ttl_seconds={ttl_seconds}, chunk_ids={doc_ids}{' and more...' if len(chunks) > 3 else ''}"
|
f"ttl_seconds={ttl_seconds}, chunk_ids={doc_ids}{' and more...' if len(chunks) > 3 else ''}"
|
||||||
)
|
)
|
||||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
return await self.routing_table.insert_chunks(vector_store_id, chunks, ttl_seconds)
|
||||||
return await provider.insert_chunks(vector_store_id, chunks, ttl_seconds)
|
|
||||||
|
|
||||||
async def query_chunks(
|
async def query_chunks(
|
||||||
self,
|
self,
|
||||||
|
|
@ -142,8 +141,7 @@ class VectorIORouter(VectorIO):
|
||||||
params: dict[str, Any] | None = None,
|
params: dict[str, Any] | None = None,
|
||||||
) -> QueryChunksResponse:
|
) -> QueryChunksResponse:
|
||||||
logger.debug(f"VectorIORouter.query_chunks: {vector_store_id}")
|
logger.debug(f"VectorIORouter.query_chunks: {vector_store_id}")
|
||||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
return await self.routing_table.query_chunks(vector_store_id, query, params)
|
||||||
return await provider.query_chunks(vector_store_id, query, params)
|
|
||||||
|
|
||||||
# OpenAI Vector Stores API endpoints
|
# OpenAI Vector Stores API endpoints
|
||||||
async def openai_create_vector_store(
|
async def openai_create_vector_store(
|
||||||
|
|
@ -248,9 +246,8 @@ class VectorIORouter(VectorIO):
|
||||||
all_stores = []
|
all_stores = []
|
||||||
for vector_store in vector_stores:
|
for vector_store in vector_stores:
|
||||||
try:
|
try:
|
||||||
provider = await self.routing_table.get_provider_impl(vector_store.identifier)
|
vector_store_obj = await self.routing_table.openai_retrieve_vector_store(vector_store.identifier)
|
||||||
vector_store = await provider.openai_retrieve_vector_store(vector_store.identifier)
|
all_stores.append(vector_store_obj)
|
||||||
all_stores.append(vector_store)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error retrieving vector store {vector_store.identifier}: {e}")
|
logger.error(f"Error retrieving vector store {vector_store.identifier}: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
@ -292,8 +289,7 @@ class VectorIORouter(VectorIO):
|
||||||
vector_store_id: str,
|
vector_store_id: str,
|
||||||
) -> VectorStoreObject:
|
) -> VectorStoreObject:
|
||||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
|
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
|
||||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
return await self.routing_table.openai_retrieve_vector_store(vector_store_id)
|
||||||
return await provider.openai_retrieve_vector_store(vector_store_id)
|
|
||||||
|
|
||||||
async def openai_update_vector_store(
|
async def openai_update_vector_store(
|
||||||
self,
|
self,
|
||||||
|
|
@ -310,8 +306,7 @@ class VectorIORouter(VectorIO):
|
||||||
if current_store and current_store.provider_id != metadata["provider_id"]:
|
if current_store and current_store.provider_id != metadata["provider_id"]:
|
||||||
raise ValueError("provider_id cannot be changed after vector store creation")
|
raise ValueError("provider_id cannot be changed after vector store creation")
|
||||||
|
|
||||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
return await self.routing_table.openai_update_vector_store(
|
||||||
return await provider.openai_update_vector_store(
|
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
name=name,
|
name=name,
|
||||||
expires_after=expires_after,
|
expires_after=expires_after,
|
||||||
|
|
@ -346,8 +341,7 @@ class VectorIORouter(VectorIO):
|
||||||
original_query = query
|
original_query = query
|
||||||
search_query = await self._rewrite_query_for_search(original_query)
|
search_query = await self._rewrite_query_for_search(original_query)
|
||||||
|
|
||||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
return await self.routing_table.openai_search_vector_store(
|
||||||
return await provider.openai_search_vector_store(
|
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
query=search_query,
|
query=search_query,
|
||||||
filters=filters,
|
filters=filters,
|
||||||
|
|
@ -367,8 +361,7 @@ class VectorIORouter(VectorIO):
|
||||||
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
|
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
|
||||||
if chunking_strategy is None or chunking_strategy.type == "auto":
|
if chunking_strategy is None or chunking_strategy.type == "auto":
|
||||||
chunking_strategy = VectorStoreChunkingStrategyStatic(static=VectorStoreChunkingStrategyStaticConfig())
|
chunking_strategy = VectorStoreChunkingStrategyStatic(static=VectorStoreChunkingStrategyStaticConfig())
|
||||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
return await self.routing_table.openai_attach_file_to_vector_store(
|
||||||
return await provider.openai_attach_file_to_vector_store(
|
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
file_id=file_id,
|
file_id=file_id,
|
||||||
attributes=attributes,
|
attributes=attributes,
|
||||||
|
|
@ -385,8 +378,7 @@ class VectorIORouter(VectorIO):
|
||||||
filter: VectorStoreFileStatus | None = None,
|
filter: VectorStoreFileStatus | None = None,
|
||||||
) -> list[VectorStoreFileObject]:
|
) -> list[VectorStoreFileObject]:
|
||||||
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
|
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
|
||||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
return await self.routing_table.openai_list_files_in_vector_store(
|
||||||
return await provider.openai_list_files_in_vector_store(
|
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
order=order,
|
order=order,
|
||||||
|
|
@ -401,8 +393,7 @@ class VectorIORouter(VectorIO):
|
||||||
file_id: str,
|
file_id: str,
|
||||||
) -> VectorStoreFileObject:
|
) -> VectorStoreFileObject:
|
||||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
|
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
|
||||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
return await self.routing_table.openai_retrieve_vector_store_file(
|
||||||
return await provider.openai_retrieve_vector_store_file(
|
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
file_id=file_id,
|
file_id=file_id,
|
||||||
)
|
)
|
||||||
|
|
@ -433,8 +424,7 @@ class VectorIORouter(VectorIO):
|
||||||
attributes: dict[str, Any],
|
attributes: dict[str, Any],
|
||||||
) -> VectorStoreFileObject:
|
) -> VectorStoreFileObject:
|
||||||
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
|
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
|
||||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
return await self.routing_table.openai_update_vector_store_file(
|
||||||
return await provider.openai_update_vector_store_file(
|
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
file_id=file_id,
|
file_id=file_id,
|
||||||
attributes=attributes,
|
attributes=attributes,
|
||||||
|
|
@ -446,8 +436,7 @@ class VectorIORouter(VectorIO):
|
||||||
file_id: str,
|
file_id: str,
|
||||||
) -> VectorStoreFileDeleteResponse:
|
) -> VectorStoreFileDeleteResponse:
|
||||||
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
|
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
|
||||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
return await self.routing_table.openai_delete_vector_store_file(
|
||||||
return await provider.openai_delete_vector_store_file(
|
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
file_id=file_id,
|
file_id=file_id,
|
||||||
)
|
)
|
||||||
|
|
@ -483,8 +472,10 @@ class VectorIORouter(VectorIO):
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(params.file_ids)} files"
|
f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(params.file_ids)} files"
|
||||||
)
|
)
|
||||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
return await self.routing_table.openai_create_vector_store_file_batch(
|
||||||
return await provider.openai_create_vector_store_file_batch(vector_store_id, params)
|
vector_store_id=vector_store_id,
|
||||||
|
params=params,
|
||||||
|
)
|
||||||
|
|
||||||
async def openai_retrieve_vector_store_file_batch(
|
async def openai_retrieve_vector_store_file_batch(
|
||||||
self,
|
self,
|
||||||
|
|
@ -492,8 +483,7 @@ class VectorIORouter(VectorIO):
|
||||||
vector_store_id: str,
|
vector_store_id: str,
|
||||||
) -> VectorStoreFileBatchObject:
|
) -> VectorStoreFileBatchObject:
|
||||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_batch: {batch_id}, {vector_store_id}")
|
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_batch: {batch_id}, {vector_store_id}")
|
||||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
return await self.routing_table.openai_retrieve_vector_store_file_batch(
|
||||||
return await provider.openai_retrieve_vector_store_file_batch(
|
|
||||||
batch_id=batch_id,
|
batch_id=batch_id,
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
)
|
)
|
||||||
|
|
@ -509,8 +499,7 @@ class VectorIORouter(VectorIO):
|
||||||
order: str | None = "desc",
|
order: str | None = "desc",
|
||||||
) -> VectorStoreFilesListInBatchResponse:
|
) -> VectorStoreFilesListInBatchResponse:
|
||||||
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store_file_batch: {batch_id}, {vector_store_id}")
|
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store_file_batch: {batch_id}, {vector_store_id}")
|
||||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
return await self.routing_table.openai_list_files_in_vector_store_file_batch(
|
||||||
return await provider.openai_list_files_in_vector_store_file_batch(
|
|
||||||
batch_id=batch_id,
|
batch_id=batch_id,
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
after=after,
|
after=after,
|
||||||
|
|
@ -526,8 +515,7 @@ class VectorIORouter(VectorIO):
|
||||||
vector_store_id: str,
|
vector_store_id: str,
|
||||||
) -> VectorStoreFileBatchObject:
|
) -> VectorStoreFileBatchObject:
|
||||||
logger.debug(f"VectorIORouter.openai_cancel_vector_store_file_batch: {batch_id}, {vector_store_id}")
|
logger.debug(f"VectorIORouter.openai_cancel_vector_store_file_batch: {batch_id}, {vector_store_id}")
|
||||||
provider = await self.routing_table.get_provider_impl(vector_store_id)
|
return await self.routing_table.openai_cancel_vector_store_file_batch(
|
||||||
return await provider.openai_cancel_vector_store_file_batch(
|
|
||||||
batch_id=batch_id,
|
batch_id=batch_id,
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -13,9 +13,13 @@ from llama_stack.log import get_logger
|
||||||
|
|
||||||
# Removed VectorStores import to avoid exposing public API
|
# Removed VectorStores import to avoid exposing public API
|
||||||
from llama_stack_api import (
|
from llama_stack_api import (
|
||||||
|
Chunk,
|
||||||
|
InterleavedContent,
|
||||||
ModelNotFoundError,
|
ModelNotFoundError,
|
||||||
ModelType,
|
ModelType,
|
||||||
ModelTypeError,
|
ModelTypeError,
|
||||||
|
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
|
||||||
|
QueryChunksResponse,
|
||||||
ResourceType,
|
ResourceType,
|
||||||
SearchRankingOptions,
|
SearchRankingOptions,
|
||||||
VectorStoreChunkingStrategy,
|
VectorStoreChunkingStrategy,
|
||||||
|
|
@ -87,6 +91,26 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
|
||||||
await self.register_object(vector_store)
|
await self.register_object(vector_store)
|
||||||
return vector_store
|
return vector_store
|
||||||
|
|
||||||
|
async def insert_chunks(
|
||||||
|
self,
|
||||||
|
vector_store_id: str,
|
||||||
|
chunks: list[Chunk],
|
||||||
|
ttl_seconds: int | None = None,
|
||||||
|
) -> None:
|
||||||
|
await self.assert_action_allowed("update", "vector_store", vector_store_id)
|
||||||
|
provider = await self.get_provider_impl(vector_store_id)
|
||||||
|
return await provider.insert_chunks(vector_store_id, chunks, ttl_seconds)
|
||||||
|
|
||||||
|
async def query_chunks(
|
||||||
|
self,
|
||||||
|
vector_store_id: str,
|
||||||
|
query: InterleavedContent,
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
) -> QueryChunksResponse:
|
||||||
|
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||||
|
provider = await self.get_provider_impl(vector_store_id)
|
||||||
|
return await provider.query_chunks(vector_store_id, query, params)
|
||||||
|
|
||||||
async def openai_retrieve_vector_store(
|
async def openai_retrieve_vector_store(
|
||||||
self,
|
self,
|
||||||
vector_store_id: str,
|
vector_store_id: str,
|
||||||
|
|
@ -142,20 +166,6 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
|
||||||
search_mode: str | None = "vector",
|
search_mode: str | None = "vector",
|
||||||
) -> VectorStoreSearchResponsePage:
|
) -> VectorStoreSearchResponsePage:
|
||||||
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
await self.assert_action_allowed("read", "vector_store", vector_store_id)
|
||||||
|
|
||||||
# Delegate to VectorIORouter if available (which handles query rewriting)
|
|
||||||
if self.vector_io_router is not None:
|
|
||||||
return await self.vector_io_router.openai_search_vector_store(
|
|
||||||
vector_store_id=vector_store_id,
|
|
||||||
query=query,
|
|
||||||
filters=filters,
|
|
||||||
max_num_results=max_num_results,
|
|
||||||
ranking_options=ranking_options,
|
|
||||||
rewrite_query=rewrite_query,
|
|
||||||
search_mode=search_mode,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Fallback to direct provider call if VectorIORouter not available
|
|
||||||
provider = await self.get_provider_impl(vector_store_id)
|
provider = await self.get_provider_impl(vector_store_id)
|
||||||
return await provider.openai_search_vector_store(
|
return await provider.openai_search_vector_store(
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
|
|
@ -261,17 +271,13 @@ class VectorStoresRoutingTable(CommonRoutingTableImpl):
|
||||||
async def openai_create_vector_store_file_batch(
|
async def openai_create_vector_store_file_batch(
|
||||||
self,
|
self,
|
||||||
vector_store_id: str,
|
vector_store_id: str,
|
||||||
file_ids: list[str],
|
params: OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
|
||||||
attributes: dict[str, Any] | None = None,
|
|
||||||
chunking_strategy: Any | None = None,
|
|
||||||
):
|
):
|
||||||
await self.assert_action_allowed("update", "vector_store", vector_store_id)
|
await self.assert_action_allowed("update", "vector_store", vector_store_id)
|
||||||
provider = await self.get_provider_impl(vector_store_id)
|
provider = await self.get_provider_impl(vector_store_id)
|
||||||
return await provider.openai_create_vector_store_file_batch(
|
return await provider.openai_create_vector_store_file_batch(
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
file_ids=file_ids,
|
params=params,
|
||||||
attributes=attributes,
|
|
||||||
chunking_strategy=chunking_strategy,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def openai_retrieve_vector_store_file_batch(
|
async def openai_retrieve_vector_store_file_batch(
|
||||||
|
|
|
||||||
|
|
@ -105,9 +105,7 @@ async def test_update_vector_store_same_provider_id_succeeds():
|
||||||
mock_existing_store.identifier = "vs_123"
|
mock_existing_store.identifier = "vs_123"
|
||||||
|
|
||||||
mock_routing_table.get_object_by_identifier = AsyncMock(return_value=mock_existing_store)
|
mock_routing_table.get_object_by_identifier = AsyncMock(return_value=mock_existing_store)
|
||||||
mock_routing_table.get_provider_impl = AsyncMock(
|
mock_routing_table.openai_update_vector_store = AsyncMock(return_value=Mock(identifier="vs_123"))
|
||||||
return_value=Mock(openai_update_vector_store=AsyncMock(return_value=Mock(id="vs_123")))
|
|
||||||
)
|
|
||||||
|
|
||||||
router = VectorIORouter(mock_routing_table)
|
router = VectorIORouter(mock_routing_table)
|
||||||
|
|
||||||
|
|
@ -118,10 +116,8 @@ async def test_update_vector_store_same_provider_id_succeeds():
|
||||||
metadata={"provider_id": "inline::faiss"}, # Same provider_id
|
metadata={"provider_id": "inline::faiss"}, # Same provider_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify the provider update method was called
|
# Verify the routing table method was called
|
||||||
mock_routing_table.get_provider_impl.assert_called_once_with("vs_123")
|
mock_routing_table.openai_update_vector_store.assert_called_once_with(
|
||||||
provider = await mock_routing_table.get_provider_impl("vs_123")
|
|
||||||
provider.openai_update_vector_store.assert_called_once_with(
|
|
||||||
vector_store_id="vs_123", name="updated_name", expires_after=None, metadata={"provider_id": "inline::faiss"}
|
vector_store_id="vs_123", name="updated_name", expires_after=None, metadata={"provider_id": "inline::faiss"}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -165,11 +161,9 @@ async def test_query_rewrite_functionality():
|
||||||
|
|
||||||
mock_routing_table = Mock()
|
mock_routing_table = Mock()
|
||||||
|
|
||||||
# Mock provider that returns search results
|
# Mock routing table method that returns search results
|
||||||
mock_provider = Mock()
|
|
||||||
mock_search_response = VectorStoreSearchResponsePage(search_query=["rewritten test query"], data=[], has_more=False)
|
mock_search_response = VectorStoreSearchResponsePage(search_query=["rewritten test query"], data=[], has_more=False)
|
||||||
mock_provider.openai_search_vector_store = AsyncMock(return_value=mock_search_response)
|
mock_routing_table.openai_search_vector_store = AsyncMock(return_value=mock_search_response)
|
||||||
mock_routing_table.get_provider_impl = AsyncMock(return_value=mock_provider)
|
|
||||||
|
|
||||||
# Mock inference API for query rewriting
|
# Mock inference API for query rewriting
|
||||||
mock_inference_api = Mock()
|
mock_inference_api = Mock()
|
||||||
|
|
@ -206,9 +200,9 @@ async def test_query_rewrite_functionality():
|
||||||
expected_prompt = DEFAULT_QUERY_REWRITE_PROMPT.format(query="test query")
|
expected_prompt = DEFAULT_QUERY_REWRITE_PROMPT.format(query="test query")
|
||||||
assert prompt_text == expected_prompt
|
assert prompt_text == expected_prompt
|
||||||
|
|
||||||
# Verify provider was called with rewritten query and rewrite_query=False
|
# Verify routing table was called with rewritten query and rewrite_query=False
|
||||||
mock_provider.openai_search_vector_store.assert_called_once()
|
mock_routing_table.openai_search_vector_store.assert_called_once()
|
||||||
call_kwargs = mock_provider.openai_search_vector_store.call_args.kwargs
|
call_kwargs = mock_routing_table.openai_search_vector_store.call_args.kwargs
|
||||||
assert call_kwargs["query"] == "rewritten test query"
|
assert call_kwargs["query"] == "rewritten test query"
|
||||||
assert call_kwargs["rewrite_query"] is False # Should be False since router handled it
|
assert call_kwargs["rewrite_query"] is False # Should be False since router handled it
|
||||||
|
|
||||||
|
|
@ -242,10 +236,8 @@ async def test_query_rewrite_with_custom_prompt():
|
||||||
|
|
||||||
mock_routing_table = Mock()
|
mock_routing_table = Mock()
|
||||||
|
|
||||||
mock_provider = Mock()
|
|
||||||
mock_search_response = VectorStoreSearchResponsePage(search_query=["custom rewrite"], data=[], has_more=False)
|
mock_search_response = VectorStoreSearchResponsePage(search_query=["custom rewrite"], data=[], has_more=False)
|
||||||
mock_provider.openai_search_vector_store = AsyncMock(return_value=mock_search_response)
|
mock_routing_table.openai_search_vector_store = AsyncMock(return_value=mock_search_response)
|
||||||
mock_routing_table.get_provider_impl = AsyncMock(return_value=mock_provider)
|
|
||||||
|
|
||||||
mock_inference_api = Mock()
|
mock_inference_api = Mock()
|
||||||
mock_inference_api.openai_chat_completion = AsyncMock(
|
mock_inference_api.openai_chat_completion = AsyncMock(
|
||||||
|
|
@ -283,10 +275,8 @@ async def test_search_without_rewrite():
|
||||||
|
|
||||||
mock_routing_table = Mock()
|
mock_routing_table = Mock()
|
||||||
|
|
||||||
mock_provider = Mock()
|
|
||||||
mock_search_response = VectorStoreSearchResponsePage(search_query=["test query"], data=[], has_more=False)
|
mock_search_response = VectorStoreSearchResponsePage(search_query=["test query"], data=[], has_more=False)
|
||||||
mock_provider.openai_search_vector_store = AsyncMock(return_value=mock_search_response)
|
mock_routing_table.openai_search_vector_store = AsyncMock(return_value=mock_search_response)
|
||||||
mock_routing_table.get_provider_impl = AsyncMock(return_value=mock_provider)
|
|
||||||
|
|
||||||
mock_inference_api = Mock()
|
mock_inference_api = Mock()
|
||||||
mock_inference_api.openai_chat_completion = AsyncMock()
|
mock_inference_api.openai_chat_completion = AsyncMock()
|
||||||
|
|
@ -303,6 +293,6 @@ async def test_search_without_rewrite():
|
||||||
# Verify inference API was NOT called
|
# Verify inference API was NOT called
|
||||||
assert not mock_inference_api.openai_chat_completion.called
|
assert not mock_inference_api.openai_chat_completion.called
|
||||||
|
|
||||||
# Verify provider was called with original query
|
# Verify routing table was called with original query
|
||||||
call_kwargs = mock_provider.openai_search_vector_store.call_args.kwargs
|
call_kwargs = mock_routing_table.openai_search_vector_store.call_args.kwargs
|
||||||
assert call_kwargs["query"] == "test query"
|
assert call_kwargs["query"] == "test query"
|
||||||
|
|
|
||||||
372
tests/unit/core/routers/test_vector_stores_abac.py
Normal file
372
tests/unit/core/routers/test_vector_stores_abac.py
Normal file
|
|
@ -0,0 +1,372 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Tests for ABAC enforcement in vector store operations.
|
||||||
|
|
||||||
|
This test suite verifies that all vector store operations properly enforce
|
||||||
|
authorization checks through the router -> routing table -> ABAC flow.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.core.routers.vector_io import VectorIORouter
|
||||||
|
from llama_stack.core.routing_tables.vector_stores import VectorStoresRoutingTable
|
||||||
|
from llama_stack_api import (
|
||||||
|
Chunk,
|
||||||
|
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
|
||||||
|
QueryChunksResponse,
|
||||||
|
VectorStoreChunkingStrategyStatic,
|
||||||
|
VectorStoreChunkingStrategyStaticConfig,
|
||||||
|
VectorStoreDeleteResponse,
|
||||||
|
VectorStoreFileBatchObject,
|
||||||
|
VectorStoreFileCounts,
|
||||||
|
VectorStoreFileDeleteResponse,
|
||||||
|
VectorStoreFileObject,
|
||||||
|
VectorStoreFilesListInBatchResponse,
|
||||||
|
VectorStoreListFilesResponse,
|
||||||
|
VectorStoreObject,
|
||||||
|
VectorStoreSearchResponsePage,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MockDistRegistry:
|
||||||
|
"""Mock distribution registry for testing."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.dist = None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_provider():
|
||||||
|
"""Create a mock provider that returns valid responses for all operations."""
|
||||||
|
provider = Mock()
|
||||||
|
|
||||||
|
provider.insert_chunks = AsyncMock()
|
||||||
|
provider.query_chunks = AsyncMock(return_value=QueryChunksResponse(chunks=[], scores=[]))
|
||||||
|
provider.openai_retrieve_vector_store = AsyncMock(
|
||||||
|
return_value=VectorStoreObject(
|
||||||
|
id="vs_123",
|
||||||
|
created_at=1234567890,
|
||||||
|
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
provider.openai_update_vector_store = AsyncMock(
|
||||||
|
return_value=VectorStoreObject(
|
||||||
|
id="vs_123",
|
||||||
|
created_at=1234567890,
|
||||||
|
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
provider.openai_delete_vector_store = AsyncMock(return_value=VectorStoreDeleteResponse(id="vs_123", deleted=True))
|
||||||
|
provider.openai_search_vector_store = AsyncMock(
|
||||||
|
return_value=VectorStoreSearchResponsePage(search_query=["test"], data=[], has_more=False)
|
||||||
|
)
|
||||||
|
provider.openai_attach_file_to_vector_store = AsyncMock(
|
||||||
|
return_value=VectorStoreFileObject(
|
||||||
|
id="file_123",
|
||||||
|
chunking_strategy=VectorStoreChunkingStrategyStatic(static=VectorStoreChunkingStrategyStaticConfig()),
|
||||||
|
created_at=1234567890,
|
||||||
|
status="completed",
|
||||||
|
vector_store_id="vs_123",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
provider.openai_list_files_in_vector_store = AsyncMock(
|
||||||
|
return_value=VectorStoreListFilesResponse(data=[], has_more=False)
|
||||||
|
)
|
||||||
|
provider.openai_retrieve_vector_store_file = AsyncMock(
|
||||||
|
return_value=VectorStoreFileObject(
|
||||||
|
id="file_123",
|
||||||
|
chunking_strategy=VectorStoreChunkingStrategyStatic(static=VectorStoreChunkingStrategyStaticConfig()),
|
||||||
|
created_at=1234567890,
|
||||||
|
status="completed",
|
||||||
|
vector_store_id="vs_123",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
provider.openai_update_vector_store_file = AsyncMock(
|
||||||
|
return_value=VectorStoreFileObject(
|
||||||
|
id="file_123",
|
||||||
|
chunking_strategy=VectorStoreChunkingStrategyStatic(static=VectorStoreChunkingStrategyStaticConfig()),
|
||||||
|
created_at=1234567890,
|
||||||
|
status="completed",
|
||||||
|
vector_store_id="vs_123",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
provider.openai_delete_vector_store_file = AsyncMock(
|
||||||
|
return_value=VectorStoreFileDeleteResponse(id="file_123", deleted=True)
|
||||||
|
)
|
||||||
|
provider.openai_create_vector_store_file_batch = AsyncMock(
|
||||||
|
return_value=VectorStoreFileBatchObject(
|
||||||
|
id="batch_123",
|
||||||
|
created_at=1234567890,
|
||||||
|
vector_store_id="vs_123",
|
||||||
|
status="in_progress",
|
||||||
|
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=2, total=2),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
provider.openai_retrieve_vector_store_file_batch = AsyncMock(
|
||||||
|
return_value=VectorStoreFileBatchObject(
|
||||||
|
id="batch_123",
|
||||||
|
created_at=1234567890,
|
||||||
|
vector_store_id="vs_123",
|
||||||
|
status="completed",
|
||||||
|
file_counts=VectorStoreFileCounts(completed=2, cancelled=0, failed=0, in_progress=0, total=2),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
provider.openai_list_files_in_vector_store_file_batch = AsyncMock(
|
||||||
|
return_value=VectorStoreFilesListInBatchResponse(data=[], has_more=False)
|
||||||
|
)
|
||||||
|
provider.openai_cancel_vector_store_file_batch = AsyncMock(
|
||||||
|
return_value=VectorStoreFileBatchObject(
|
||||||
|
id="batch_123",
|
||||||
|
created_at=1234567890,
|
||||||
|
vector_store_id="vs_123",
|
||||||
|
status="cancelled",
|
||||||
|
file_counts=VectorStoreFileCounts(completed=0, cancelled=2, failed=0, in_progress=0, total=2),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return provider
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def router_with_real_routing_table(mock_provider):
|
||||||
|
"""Create router with real routing table for integration testing."""
|
||||||
|
mock_dist_registry = MockDistRegistry()
|
||||||
|
|
||||||
|
routing_table = VectorStoresRoutingTable(
|
||||||
|
impls_by_provider_id={"test-provider": mock_provider},
|
||||||
|
dist_registry=mock_dist_registry,
|
||||||
|
policy=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock get_provider_impl to return our mock provider
|
||||||
|
routing_table.get_provider_impl = AsyncMock(return_value=mock_provider)
|
||||||
|
|
||||||
|
# Mock get_object_by_identifier to return a mock vector store object
|
||||||
|
# This is needed by assert_action_allowed to check permissions
|
||||||
|
from llama_stack.core.datatypes import VectorStoreWithOwner
|
||||||
|
from llama_stack_api import ResourceType
|
||||||
|
|
||||||
|
mock_vector_store = VectorStoreWithOwner(
|
||||||
|
identifier="vs_123",
|
||||||
|
provider_id="test-provider",
|
||||||
|
provider_resource_id="vs_123",
|
||||||
|
type=ResourceType.vector_store,
|
||||||
|
embedding_model="test-model",
|
||||||
|
embedding_dimension=768,
|
||||||
|
)
|
||||||
|
routing_table.get_object_by_identifier = AsyncMock(return_value=mock_vector_store)
|
||||||
|
|
||||||
|
# Spy on assert_action_allowed to verify it's called correctly
|
||||||
|
original_assert = routing_table.assert_action_allowed
|
||||||
|
routing_table.assert_action_allowed = AsyncMock(wraps=original_assert)
|
||||||
|
|
||||||
|
# Create router with real routing table
|
||||||
|
router = VectorIORouter(routing_table)
|
||||||
|
|
||||||
|
return router, routing_table, mock_provider
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
    "op_name,action,invoke,provider_attr",
    [
        (
            "insert_chunks",
            "update",
            lambda r: r.insert_chunks("vs_123", [Chunk(content="test", chunk_id="c1")]),
            "insert_chunks",
        ),
        ("query_chunks", "read", lambda r: r.query_chunks("vs_123", "test"), "query_chunks"),
        (
            "openai_retrieve_vector_store",
            "read",
            lambda r: r.openai_retrieve_vector_store("vs_123"),
            "openai_retrieve_vector_store",
        ),
        (
            "openai_update_vector_store",
            "update",
            lambda r: r.openai_update_vector_store("vs_123", name="test"),
            "openai_update_vector_store",
        ),
        (
            "openai_delete_vector_store",
            "delete",
            lambda r: r.openai_delete_vector_store("vs_123"),
            "openai_delete_vector_store",
        ),
        (
            "openai_search_vector_store",
            "read",
            lambda r: r.openai_search_vector_store("vs_123", query="test"),
            "openai_search_vector_store",
        ),
        (
            "openai_attach_file_to_vector_store",
            "update",
            lambda r: r.openai_attach_file_to_vector_store("vs_123", "file_123"),
            "openai_attach_file_to_vector_store",
        ),
        (
            "openai_list_files_in_vector_store",
            "read",
            lambda r: r.openai_list_files_in_vector_store("vs_123"),
            "openai_list_files_in_vector_store",
        ),
        (
            "openai_retrieve_vector_store_file",
            "read",
            lambda r: r.openai_retrieve_vector_store_file("vs_123", "file_123"),
            "openai_retrieve_vector_store_file",
        ),
        (
            "openai_update_vector_store_file",
            "update",
            lambda r: r.openai_update_vector_store_file("vs_123", "file_123", {}),
            "openai_update_vector_store_file",
        ),
        (
            "openai_delete_vector_store_file",
            "delete",
            lambda r: r.openai_delete_vector_store_file("vs_123", "file_123"),
            "openai_delete_vector_store_file",
        ),
        (
            "openai_create_vector_store_file_batch",
            "update",
            lambda r: r.openai_create_vector_store_file_batch(
                "vs_123", OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["f1"])
            ),
            "openai_create_vector_store_file_batch",
        ),
        (
            "openai_retrieve_vector_store_file_batch",
            "read",
            lambda r: r.openai_retrieve_vector_store_file_batch("batch_123", "vs_123"),
            "openai_retrieve_vector_store_file_batch",
        ),
        (
            "openai_list_files_in_vector_store_file_batch",
            "read",
            lambda r: r.openai_list_files_in_vector_store_file_batch("batch_123", "vs_123"),
            "openai_list_files_in_vector_store_file_batch",
        ),
        (
            "openai_cancel_vector_store_file_batch",
            "update",
            lambda r: r.openai_cancel_vector_store_file_batch("batch_123", "vs_123"),
            "openai_cancel_vector_store_file_batch",
        ),
    ],
)
async def test_operation_enforces_correct_abac_permission(
    op_name, action, invoke, provider_attr, router_with_real_routing_table
):
    """Each operation flows router -> routing table -> ABAC check -> provider.

    Verifies that the router delegates to the routing table rather than the
    provider directly, that the routing table enforces ABAC with the expected
    action, and that the provider is only reached after authorization passes.
    """
    router, table, provider = router_with_real_routing_table

    # Drive the operation through the router's public surface.
    await invoke(router)

    # The ABAC check must have fired with the right action and resource.
    table.assert_action_allowed.assert_called_once_with(action, "vector_store", "vs_123")

    # ...and only then was the provider's method invoked.
    getattr(provider, provider_attr).assert_called_once()
async def test_operations_fail_before_provider_when_unauthorized(router_with_real_routing_table):
    """Every operation must fail at the ABAC check and never reach the provider."""
    router, table, provider = router_with_real_routing_table

    # One entry per router operation, keyed by name for failure messages.
    operations = {
        "insert_chunks": lambda: router.insert_chunks("vs_123", [Chunk(content="test", chunk_id="c1")]),
        "query_chunks": lambda: router.query_chunks("vs_123", "test"),
        "openai_retrieve_vector_store": lambda: router.openai_retrieve_vector_store("vs_123"),
        "openai_update_vector_store": lambda: router.openai_update_vector_store("vs_123", name="test"),
        "openai_delete_vector_store": lambda: router.openai_delete_vector_store("vs_123"),
        "openai_search_vector_store": lambda: router.openai_search_vector_store("vs_123", query="test"),
        "openai_attach_file_to_vector_store": lambda: router.openai_attach_file_to_vector_store(
            "vs_123", "file_123"
        ),
        "openai_list_files_in_vector_store": lambda: router.openai_list_files_in_vector_store("vs_123"),
        "openai_retrieve_vector_store_file": lambda: router.openai_retrieve_vector_store_file(
            "vs_123", "file_123"
        ),
        "openai_update_vector_store_file": lambda: router.openai_update_vector_store_file(
            "vs_123", "file_123", {}
        ),
        "openai_delete_vector_store_file": lambda: router.openai_delete_vector_store_file(
            "vs_123", "file_123"
        ),
        "openai_create_vector_store_file_batch": lambda: router.openai_create_vector_store_file_batch(
            "vs_123", OpenAICreateVectorStoreFileBatchRequestWithExtraBody(file_ids=["f1"])
        ),
        "openai_retrieve_vector_store_file_batch": lambda: router.openai_retrieve_vector_store_file_batch(
            "batch_123", "vs_123"
        ),
        "openai_list_files_in_vector_store_file_batch": lambda: router.openai_list_files_in_vector_store_file_batch(
            "batch_123", "vs_123"
        ),
        "openai_cancel_vector_store_file_batch": lambda: router.openai_cancel_vector_store_file_batch(
            "batch_123", "vs_123"
        ),
    }

    for op_name, call in operations.items():
        # Fresh mock state per operation; the ABAC check always denies.
        table.assert_action_allowed.reset_mock()
        table.assert_action_allowed.side_effect = PermissionError("Access denied")

        # The denial must surface to the caller as-is.
        with pytest.raises(PermissionError, match="Access denied"):
            await call()

        assert table.assert_action_allowed.called, f"{op_name} should check permissions"

    # The provider must never have been touched for any of the 15 operations.
    for method_name in (
        "insert_chunks",
        "query_chunks",
        "openai_retrieve_vector_store",
        "openai_update_vector_store",
        "openai_delete_vector_store",
        "openai_search_vector_store",
        "openai_attach_file_to_vector_store",
        "openai_list_files_in_vector_store",
        "openai_retrieve_vector_store_file",
        "openai_update_vector_store_file",
        "openai_delete_vector_store_file",
        "openai_create_vector_store_file_batch",
        "openai_retrieve_vector_store_file_batch",
        "openai_list_files_in_vector_store_file_batch",
        "openai_cancel_vector_store_file_batch",
    ):
        getattr(provider, method_name).assert_not_called()
Loading…
Add table
Add a link
Reference in a new issue