diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py index cd56ada7b..a1dd66060 100644 --- a/llama_stack/distribution/routers/vector_io.py +++ b/llama_stack/distribution/routers/vector_io.py @@ -214,9 +214,7 @@ class VectorIORouter(VectorIO): vector_store_id: str, ) -> VectorStoreObject: logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_retrieve_vector_store(vector_store_id) + return await self.routing_table.openai_retrieve_vector_store(vector_store_id) async def openai_update_vector_store( self, @@ -226,9 +224,7 @@ class VectorIORouter(VectorIO): metadata: dict[str, Any] | None = None, ) -> VectorStoreObject: logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_update_vector_store( + return await self.routing_table.openai_update_vector_store( vector_store_id=vector_store_id, name=name, expires_after=expires_after, @@ -240,12 +236,7 @@ class VectorIORouter(VectorIO): vector_store_id: str, ) -> VectorStoreDeleteResponse: logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - result = await provider.openai_delete_vector_store(vector_store_id) - # drop from registry - await self.routing_table.unregister_vector_db(vector_store_id) - return result + return await self.routing_table.openai_delete_vector_store(vector_store_id) async def openai_search_vector_store( self, @@ -258,9 +249,7 @@ class VectorIORouter(VectorIO): search_mode: str | None = "vector", ) -> VectorStoreSearchResponsePage: logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_search_vector_store( + return await self.routing_table.openai_search_vector_store( vector_store_id=vector_store_id, query=query, filters=filters, @@ -278,9 +267,7 @@ class VectorIORouter(VectorIO): chunking_strategy: VectorStoreChunkingStrategy | None = None, ) -> VectorStoreFileObject: logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_attach_file_to_vector_store( + return await self.routing_table.openai_attach_file_to_vector_store( vector_store_id=vector_store_id, file_id=file_id, attributes=attributes, @@ -297,9 +284,7 @@ class VectorIORouter(VectorIO): filter: VectorStoreFileStatus | None = None, ) -> list[VectorStoreFileObject]: logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_list_files_in_vector_store( + return await self.routing_table.openai_list_files_in_vector_store( vector_store_id=vector_store_id, limit=limit, order=order, @@ -314,9 +299,7 @@ class VectorIORouter(VectorIO): file_id: str, ) -> VectorStoreFileObject: logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_retrieve_vector_store_file( + return await self.routing_table.openai_retrieve_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, ) @@ -327,9 +310,7 @@ class VectorIORouter(VectorIO): file_id: str, ) -> VectorStoreFileContentsResponse: logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_retrieve_vector_store_file_contents( + return await self.routing_table.openai_retrieve_vector_store_file_contents( vector_store_id=vector_store_id, file_id=file_id, ) @@ -341,9 +322,7 @@ class VectorIORouter(VectorIO): attributes: dict[str, Any], ) -> VectorStoreFileObject: logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_update_vector_store_file( + return await self.routing_table.openai_update_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, attributes=attributes, @@ -355,9 +334,7 @@ class VectorIORouter(VectorIO): file_id: str, ) -> VectorStoreFileDeleteResponse: logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_delete_vector_store_file( + return await self.routing_table.openai_delete_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, ) diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py index 7f7de32fe..bbe0113e9 100644 --- a/llama_stack/distribution/routing_tables/common.py +++ b/llama_stack/distribution/routing_tables/common.py @@ -9,6 +9,7 @@ from typing import Any from llama_stack.apis.resource import ResourceType from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed +from llama_stack.distribution.access_control.datatypes import Action from llama_stack.distribution.datatypes import ( AccessRule, RoutableObject, @@ -209,6 +210,20 @@ class CommonRoutingTableImpl(RoutingTable): await self.dist_registry.register(obj) return obj + async def assert_action_allowed( + self, + action: Action, + type: str, + identifier: str, + ) -> None: + """Fetch a registered object by type/identifier and enforce the given action permission.""" + obj = await self.get_object_by_identifier(type, identifier) + if obj is None: + raise ValueError(f"{type.capitalize()} '{identifier}' not found") + user = get_authenticated_user() + if not is_action_allowed(self.policy, action, obj, user): + raise AccessDeniedError(action, obj, user) + async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]: objs = await self.dist_registry.get_all() filtered_objs = [obj for obj in objs if obj.type == type] diff --git a/llama_stack/distribution/routing_tables/vector_dbs.py b/llama_stack/distribution/routing_tables/vector_dbs.py index f861102c8..b4e60c625 100644 --- a/llama_stack/distribution/routing_tables/vector_dbs.py +++ b/llama_stack/distribution/routing_tables/vector_dbs.py @@ -4,11 +4,24 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Any + from pydantic import TypeAdapter from llama_stack.apis.models import ModelType from llama_stack.apis.resource import ResourceType from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs +from llama_stack.apis.vector_io.vector_io import ( + SearchRankingOptions, + VectorStoreChunkingStrategy, + VectorStoreDeleteResponse, + VectorStoreFileContentsResponse, + VectorStoreFileDeleteResponse, + VectorStoreFileObject, + VectorStoreFileStatus, + VectorStoreObject, + VectorStoreSearchResponsePage, +) from llama_stack.distribution.datatypes import ( VectorDBWithOwner, ) @@ -74,3 +87,135 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): if existing_vector_db is None: raise ValueError(f"Vector DB {vector_db_id} not found") await self.unregister_object(existing_vector_db) + + async def openai_retrieve_vector_store( + self, + vector_store_id: str, + ) -> VectorStoreObject: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store(vector_store_id) + + async def openai_update_vector_store( + self, + vector_store_id: str, + name: str | None = None, + expires_after: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + ) -> VectorStoreObject: + await self.assert_action_allowed("update", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_update_vector_store( + vector_store_id=vector_store_id, + name=name, + expires_after=expires_after, + metadata=metadata, + ) + + async def openai_delete_vector_store( + self, + vector_store_id: str, + ) -> VectorStoreDeleteResponse: + await self.assert_action_allowed("delete", "vector_db", vector_store_id) + result = await self.get_provider_impl(vector_store_id).openai_delete_vector_store(vector_store_id) + await self.unregister_vector_db(vector_store_id) + return result + + async def openai_search_vector_store( + self, + vector_store_id: str, + query: str | list[str], + filters: dict[str, Any] | None = None, + max_num_results: int | None = 10, + ranking_options: SearchRankingOptions | None = None, + rewrite_query: bool | None = False, + search_mode: str | None = "vector", + ) -> VectorStoreSearchResponsePage: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_search_vector_store( + vector_store_id=vector_store_id, + query=query, + filters=filters, + max_num_results=max_num_results, + ranking_options=ranking_options, + rewrite_query=rewrite_query, + search_mode=search_mode, + ) + + async def openai_attach_file_to_vector_store( + self, + vector_store_id: str, + file_id: str, + attributes: dict[str, Any] | None = None, + chunking_strategy: VectorStoreChunkingStrategy | None = None, + ) -> VectorStoreFileObject: + await self.assert_action_allowed("update", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_attach_file_to_vector_store( + vector_store_id=vector_store_id, + file_id=file_id, + attributes=attributes, + chunking_strategy=chunking_strategy, + ) + + async def openai_list_files_in_vector_store( + self, + vector_store_id: str, + limit: int | None = 20, + order: str | None = "desc", + after: str | None = None, + before: str | None = None, + filter: VectorStoreFileStatus | None = None, + ) -> list[VectorStoreFileObject]: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_list_files_in_vector_store( + vector_store_id=vector_store_id, + limit=limit, + order=order, + after=after, + before=before, + filter=filter, + ) + + async def openai_retrieve_vector_store_file( + self, + vector_store_id: str, + file_id: str, + ) -> VectorStoreFileObject: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file( + vector_store_id=vector_store_id, + file_id=file_id, + ) + + async def openai_retrieve_vector_store_file_contents( + self, + vector_store_id: str, + file_id: str, + ) -> VectorStoreFileContentsResponse: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file_contents( + vector_store_id=vector_store_id, + file_id=file_id, + ) + + async def openai_update_vector_store_file( + self, + vector_store_id: str, + file_id: str, + attributes: dict[str, Any], + ) -> VectorStoreFileObject: + await self.assert_action_allowed("update", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_update_vector_store_file( + vector_store_id=vector_store_id, + file_id=file_id, + attributes=attributes, + ) + + async def openai_delete_vector_store_file( + self, + vector_store_id: str, + file_id: str, + ) -> VectorStoreFileDeleteResponse: + await self.assert_action_allowed("delete", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_delete_vector_store_file( + vector_store_id=vector_store_id, + file_id=file_id, + ) diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 3ba042bd9..30f795d33 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -11,17 +11,15 @@ from unittest.mock import AsyncMock from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datatypes import Api -from llama_stack.apis.models import Model, ModelType +from llama_stack.apis.models import Model from llama_stack.apis.shields.shields import Shield from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter -from llama_stack.apis.vector_dbs.vector_dbs import VectorDB from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable from llama_stack.distribution.routing_tables.models import ModelsRoutingTable from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable -from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable class Impl: @@ -54,17 +52,6 @@ class SafetyImpl(Impl): return shield -class VectorDBImpl(Impl): - def __init__(self): - super().__init__(Api.vector_io) - - async def register_vector_db(self, vector_db: VectorDB): - return vector_db - - async def unregister_vector_db(self, vector_db_id: str): - return vector_db_id - - class DatasetsImpl(Impl): def __init__(self): super().__init__(Api.datasetio) @@ -173,36 +160,6 @@ async def test_shields_routing_table(cached_disk_dist_registry): assert "test-shield-2" in shield_ids -async def test_vectordbs_routing_table(cached_disk_dist_registry): - table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) - await table.initialize() - - m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) - await m_table.initialize() - await m_table.register_model( - model_id="test-model", - provider_id="test_provider", - metadata={"embedding_dimension": 128}, - model_type=ModelType.embedding, - ) - - # Register multiple vector databases and verify listing - await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model") - await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model") - vector_dbs = await table.list_vector_dbs() - - assert len(vector_dbs.data) == 2 - vector_db_ids = {v.identifier for v in vector_dbs.data} - assert "test-vectordb" in vector_db_ids - assert "test-vectordb-2" in vector_db_ids - - await table.unregister_vector_db(vector_db_id="test-vectordb") - await table.unregister_vector_db(vector_db_id="test-vectordb-2") - - vector_dbs = await table.list_vector_dbs() - assert len(vector_dbs.data) == 0 - - async def test_datasets_routing_table(cached_disk_dist_registry): table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {}) await table.initialize() diff --git a/tests/unit/distribution/routing_tables/__init__.py b/tests/unit/distribution/routing_tables/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/unit/distribution/routing_tables/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/unit/distribution/routing_tables/test_vector_dbs.py b/tests/unit/distribution/routing_tables/test_vector_dbs.py new file mode 100644 index 000000000..28887e1cf --- /dev/null +++ b/tests/unit/distribution/routing_tables/test_vector_dbs.py @@ -0,0 +1,274 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Unit tests for the routing tables vector_dbs + +import time +from unittest.mock import AsyncMock + +import pytest + +from llama_stack.apis.datatypes import Api +from llama_stack.apis.models import ModelType +from llama_stack.apis.vector_dbs.vector_dbs import VectorDB +from llama_stack.apis.vector_io.vector_io import ( + VectorStoreContent, + VectorStoreDeleteResponse, + VectorStoreFileContentsResponse, + VectorStoreFileCounts, + VectorStoreFileDeleteResponse, + VectorStoreFileObject, + VectorStoreObject, + VectorStoreSearchResponsePage, +) +from llama_stack.distribution.access_control.datatypes import AccessRule, Scope +from llama_stack.distribution.datatypes import User +from llama_stack.distribution.request_headers import request_provider_data_context +from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable +from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceImpl, ModelsRoutingTable + + +class VectorDBImpl(Impl): + def __init__(self): + super().__init__(Api.vector_io) + + async def register_vector_db(self, vector_db: VectorDB): + return vector_db + + async def unregister_vector_db(self, vector_db_id: str): + return vector_db_id + + async def openai_retrieve_vector_store(self, vector_store_id): + return VectorStoreObject( + id=vector_store_id, + name="Test Store", + created_at=int(time.time()), + file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0), + ) + + async def openai_update_vector_store(self, vector_store_id, **kwargs): + return VectorStoreObject( + id=vector_store_id, + name="Updated Store", + created_at=int(time.time()), + file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0), + ) + + async def openai_delete_vector_store(self, vector_store_id): + return VectorStoreDeleteResponse(id=vector_store_id, object="vector_store.deleted", deleted=True) + + async def openai_search_vector_store(self, vector_store_id, query, **kwargs): + return VectorStoreSearchResponsePage( + object="vector_store.search_results.page", search_query="query", data=[], has_more=False, next_page=None + ) + + async def openai_attach_file_to_vector_store(self, vector_store_id, file_id, **kwargs): + return VectorStoreFileObject( + id=file_id, + status="completed", + chunking_strategy={"type": "auto"}, + created_at=int(time.time()), + vector_store_id=vector_store_id, + ) + + async def openai_list_files_in_vector_store(self, vector_store_id, **kwargs): + return [ + VectorStoreFileObject( + id="1", + status="completed", + chunking_strategy={"type": "auto"}, + created_at=int(time.time()), + vector_store_id=vector_store_id, + ) + ] + + async def openai_retrieve_vector_store_file(self, vector_store_id, file_id): + return VectorStoreFileObject( + id=file_id, + status="completed", + chunking_strategy={"type": "auto"}, + created_at=int(time.time()), + vector_store_id=vector_store_id, + ) + + async def openai_retrieve_vector_store_file_contents(self, vector_store_id, file_id): + return VectorStoreFileContentsResponse( + file_id=file_id, + filename="Sample File name", + attributes={"key": "value"}, + content=[VectorStoreContent(type="text", text="Sample content")], + ) + + async def openai_update_vector_store_file(self, vector_store_id, file_id, **kwargs): + return VectorStoreFileObject( + id=file_id, + status="completed", + chunking_strategy={"type": "auto"}, + created_at=int(time.time()), + vector_store_id=vector_store_id, + ) + + async def openai_delete_vector_store_file(self, vector_store_id, file_id): + return VectorStoreFileDeleteResponse(id=file_id, deleted=True) + + +async def test_vectordbs_routing_table(cached_disk_dist_registry): + table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + # Register multiple vector databases and verify listing + await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model") + await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model") + vector_dbs = await table.list_vector_dbs() + + assert len(vector_dbs.data) == 2 + vector_db_ids = {v.identifier for v in vector_dbs.data} + assert "test-vectordb" in vector_db_ids + assert "test-vectordb-2" in vector_db_ids + + await table.unregister_vector_db(vector_db_id="test-vectordb") + await table.unregister_vector_db(vector_db_id="test-vectordb-2") + + vector_dbs = await table.list_vector_dbs() + assert len(vector_dbs.data) == 0 + + +async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry): + impl = VectorDBImpl() + impl.openai_retrieve_vector_store = AsyncMock(return_value="OK") + table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=[]) + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[]) + authorized_table = "vs1" + authorized_team = "team1" + unauthorized_team = "team2" + + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + authorized_user = User(principal="alice", attributes={"roles": [authorized_team]}) + with request_provider_data_context({}, authorized_user): + _ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model") + + # Authorized reader + with request_provider_data_context({}, authorized_user): + res = await table.openai_retrieve_vector_store(authorized_table) + assert res == "OK" + + # Authorized updater + impl.openai_update_vector_store_file = AsyncMock(return_value="UPDATED") + with request_provider_data_context({}, authorized_user): + res = await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"}) + assert res == "UPDATED" + + # Unauthorized reader + unauthorized_user = User(principal="eve", attributes={"roles": [unauthorized_team]}) + with request_provider_data_context({}, unauthorized_user): + with pytest.raises(ValueError): + await table.openai_retrieve_vector_store(authorized_table) + + # Unauthorized updater + with request_provider_data_context({}, unauthorized_user): + with pytest.raises(ValueError): + await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"}) + + # Authorized deleter + impl.openai_delete_vector_store_file = AsyncMock(return_value="DELETED") + with request_provider_data_context({}, authorized_user): + res = await table.openai_delete_vector_store_file(authorized_table, file_id="file1") + assert res == "DELETED" + + # Unauthorized deleter + with request_provider_data_context({}, unauthorized_user): + with pytest.raises(ValueError): + await table.openai_delete_vector_store_file(authorized_table, file_id="file1") + + +async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_registry): + impl = VectorDBImpl() + + policy = [ + AccessRule(permit=Scope(actions=["create", "read", "update", "delete"]), when="user with admin in roles"), + AccessRule(permit=Scope(actions=["read"]), when="user with reader in roles"), + ] + + table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=policy) + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[]) + + vector_db_id = "vs1" + file_id = "file-1" + + admin_user = User(principal="admin", attributes={"roles": ["admin"]}) + read_only_user = User(principal="reader", attributes={"roles": ["reader"]}) + no_access_user = User(principal="outsider", attributes={"roles": ["no_access"]}) + + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + with request_provider_data_context({}, admin_user): + await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model") + + read_methods = [ + (table.openai_retrieve_vector_store, (vector_db_id,), {}), + (table.openai_search_vector_store, (vector_db_id, "query"), {}), + (table.openai_list_files_in_vector_store, (vector_db_id,), {}), + (table.openai_retrieve_vector_store_file, (vector_db_id, file_id), {}), + (table.openai_retrieve_vector_store_file_contents, (vector_db_id, file_id), {}), + ] + update_methods = [ + (table.openai_update_vector_store, (vector_db_id,), {"name": "Updated DB"}), + (table.openai_attach_file_to_vector_store, (vector_db_id, file_id), {}), + (table.openai_update_vector_store_file, (vector_db_id, file_id), {"attributes": {"key": "value"}}), + ] + delete_methods = [ + (table.openai_delete_vector_store_file, (vector_db_id, file_id), {}), + (table.openai_delete_vector_store, (vector_db_id,), {}), + ] + + for user in [admin_user, read_only_user]: + with request_provider_data_context({}, user): + for method, args, kwargs in read_methods: + result = await method(*args, **kwargs) + assert result is not None, f"Read operation failed with user {user.principal}" + + with request_provider_data_context({}, no_access_user): + for method, args, kwargs in read_methods: + with pytest.raises(ValueError): + await method(*args, **kwargs) + + with request_provider_data_context({}, admin_user): + for method, args, kwargs in update_methods: + result = await method(*args, **kwargs) + assert result is not None, "Update operation failed with admin user" + + with request_provider_data_context({}, admin_user): + for method, args, kwargs in delete_methods: + result = await method(*args, **kwargs) + assert result is not None, "Delete operation failed with admin user" + + for user in [read_only_user, no_access_user]: + with request_provider_data_context({}, user): + for method, args, kwargs in delete_methods: + with pytest.raises(ValueError): + await method(*args, **kwargs)