mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-23 21:04:29 +00:00
chore: Adding Access Control for OpenAI Vector Stores methods (#2772)
# What does this PR do? Refactors the vector store routing logic by moving OpenAI-compatible vector store operations from the `VectorIORouter` to the `VectorDBsRoutingTable`. Closes https://github.com/meta-llama/llama-stack/issues/2761 ## Test Plan Added unit tests to cover new routing logic and ACL checks. --------- Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
0d7a90b8bc
commit
c8f274347d
6 changed files with 450 additions and 77 deletions
|
@ -214,9 +214,7 @@ class VectorIORouter(VectorIO):
|
||||||
vector_store_id: str,
|
vector_store_id: str,
|
||||||
) -> VectorStoreObject:
|
) -> VectorStoreObject:
|
||||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
|
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
|
||||||
# Route based on vector store ID
|
return await self.routing_table.openai_retrieve_vector_store(vector_store_id)
|
||||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
||||||
return await provider.openai_retrieve_vector_store(vector_store_id)
|
|
||||||
|
|
||||||
async def openai_update_vector_store(
|
async def openai_update_vector_store(
|
||||||
self,
|
self,
|
||||||
|
@ -226,9 +224,7 @@ class VectorIORouter(VectorIO):
|
||||||
metadata: dict[str, Any] | None = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
) -> VectorStoreObject:
|
) -> VectorStoreObject:
|
||||||
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
|
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
|
||||||
# Route based on vector store ID
|
return await self.routing_table.openai_update_vector_store(
|
||||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
||||||
return await provider.openai_update_vector_store(
|
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
name=name,
|
name=name,
|
||||||
expires_after=expires_after,
|
expires_after=expires_after,
|
||||||
|
@ -240,12 +236,7 @@ class VectorIORouter(VectorIO):
|
||||||
vector_store_id: str,
|
vector_store_id: str,
|
||||||
) -> VectorStoreDeleteResponse:
|
) -> VectorStoreDeleteResponse:
|
||||||
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
|
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
|
||||||
# Route based on vector store ID
|
return await self.routing_table.openai_delete_vector_store(vector_store_id)
|
||||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
||||||
result = await provider.openai_delete_vector_store(vector_store_id)
|
|
||||||
# drop from registry
|
|
||||||
await self.routing_table.unregister_vector_db(vector_store_id)
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def openai_search_vector_store(
|
async def openai_search_vector_store(
|
||||||
self,
|
self,
|
||||||
|
@ -258,9 +249,7 @@ class VectorIORouter(VectorIO):
|
||||||
search_mode: str | None = "vector",
|
search_mode: str | None = "vector",
|
||||||
) -> VectorStoreSearchResponsePage:
|
) -> VectorStoreSearchResponsePage:
|
||||||
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
|
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
|
||||||
# Route based on vector store ID
|
return await self.routing_table.openai_search_vector_store(
|
||||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
||||||
return await provider.openai_search_vector_store(
|
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
query=query,
|
query=query,
|
||||||
filters=filters,
|
filters=filters,
|
||||||
|
@ -278,9 +267,7 @@ class VectorIORouter(VectorIO):
|
||||||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||||
) -> VectorStoreFileObject:
|
) -> VectorStoreFileObject:
|
||||||
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
|
logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
|
||||||
# Route based on vector store ID
|
return await self.routing_table.openai_attach_file_to_vector_store(
|
||||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
||||||
return await provider.openai_attach_file_to_vector_store(
|
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
file_id=file_id,
|
file_id=file_id,
|
||||||
attributes=attributes,
|
attributes=attributes,
|
||||||
|
@ -297,9 +284,7 @@ class VectorIORouter(VectorIO):
|
||||||
filter: VectorStoreFileStatus | None = None,
|
filter: VectorStoreFileStatus | None = None,
|
||||||
) -> list[VectorStoreFileObject]:
|
) -> list[VectorStoreFileObject]:
|
||||||
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
|
logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}")
|
||||||
# Route based on vector store ID
|
return await self.routing_table.openai_list_files_in_vector_store(
|
||||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
||||||
return await provider.openai_list_files_in_vector_store(
|
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
limit=limit,
|
limit=limit,
|
||||||
order=order,
|
order=order,
|
||||||
|
@ -314,9 +299,7 @@ class VectorIORouter(VectorIO):
|
||||||
file_id: str,
|
file_id: str,
|
||||||
) -> VectorStoreFileObject:
|
) -> VectorStoreFileObject:
|
||||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
|
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}")
|
||||||
# Route based on vector store ID
|
return await self.routing_table.openai_retrieve_vector_store_file(
|
||||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
||||||
return await provider.openai_retrieve_vector_store_file(
|
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
file_id=file_id,
|
file_id=file_id,
|
||||||
)
|
)
|
||||||
|
@ -327,9 +310,7 @@ class VectorIORouter(VectorIO):
|
||||||
file_id: str,
|
file_id: str,
|
||||||
) -> VectorStoreFileContentsResponse:
|
) -> VectorStoreFileContentsResponse:
|
||||||
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
|
logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}")
|
||||||
# Route based on vector store ID
|
return await self.routing_table.openai_retrieve_vector_store_file_contents(
|
||||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
||||||
return await provider.openai_retrieve_vector_store_file_contents(
|
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
file_id=file_id,
|
file_id=file_id,
|
||||||
)
|
)
|
||||||
|
@ -341,9 +322,7 @@ class VectorIORouter(VectorIO):
|
||||||
attributes: dict[str, Any],
|
attributes: dict[str, Any],
|
||||||
) -> VectorStoreFileObject:
|
) -> VectorStoreFileObject:
|
||||||
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
|
logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}")
|
||||||
# Route based on vector store ID
|
return await self.routing_table.openai_update_vector_store_file(
|
||||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
||||||
return await provider.openai_update_vector_store_file(
|
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
file_id=file_id,
|
file_id=file_id,
|
||||||
attributes=attributes,
|
attributes=attributes,
|
||||||
|
@ -355,9 +334,7 @@ class VectorIORouter(VectorIO):
|
||||||
file_id: str,
|
file_id: str,
|
||||||
) -> VectorStoreFileDeleteResponse:
|
) -> VectorStoreFileDeleteResponse:
|
||||||
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
|
logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}")
|
||||||
# Route based on vector store ID
|
return await self.routing_table.openai_delete_vector_store_file(
|
||||||
provider = self.routing_table.get_provider_impl(vector_store_id)
|
|
||||||
return await provider.openai_delete_vector_store_file(
|
|
||||||
vector_store_id=vector_store_id,
|
vector_store_id=vector_store_id,
|
||||||
file_id=file_id,
|
file_id=file_id,
|
||||||
)
|
)
|
||||||
|
|
|
@ -9,6 +9,7 @@ from typing import Any
|
||||||
from llama_stack.apis.resource import ResourceType
|
from llama_stack.apis.resource import ResourceType
|
||||||
from llama_stack.apis.scoring_functions import ScoringFn
|
from llama_stack.apis.scoring_functions import ScoringFn
|
||||||
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||||
|
from llama_stack.distribution.access_control.datatypes import Action
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
AccessRule,
|
AccessRule,
|
||||||
RoutableObject,
|
RoutableObject,
|
||||||
|
@ -209,6 +210,20 @@ class CommonRoutingTableImpl(RoutingTable):
|
||||||
await self.dist_registry.register(obj)
|
await self.dist_registry.register(obj)
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
async def assert_action_allowed(
|
||||||
|
self,
|
||||||
|
action: Action,
|
||||||
|
type: str,
|
||||||
|
identifier: str,
|
||||||
|
) -> None:
|
||||||
|
"""Fetch a registered object by type/identifier and enforce the given action permission."""
|
||||||
|
obj = await self.get_object_by_identifier(type, identifier)
|
||||||
|
if obj is None:
|
||||||
|
raise ValueError(f"{type.capitalize()} '{identifier}' not found")
|
||||||
|
user = get_authenticated_user()
|
||||||
|
if not is_action_allowed(self.policy, action, obj, user):
|
||||||
|
raise AccessDeniedError(action, obj, user)
|
||||||
|
|
||||||
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
|
async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]:
|
||||||
objs = await self.dist_registry.get_all()
|
objs = await self.dist_registry.get_all()
|
||||||
filtered_objs = [obj for obj in objs if obj.type == type]
|
filtered_objs = [obj for obj in objs if obj.type == type]
|
||||||
|
|
|
@ -4,11 +4,24 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
from llama_stack.apis.models import ModelType
|
from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.apis.resource import ResourceType
|
from llama_stack.apis.resource import ResourceType
|
||||||
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
||||||
|
from llama_stack.apis.vector_io.vector_io import (
|
||||||
|
SearchRankingOptions,
|
||||||
|
VectorStoreChunkingStrategy,
|
||||||
|
VectorStoreDeleteResponse,
|
||||||
|
VectorStoreFileContentsResponse,
|
||||||
|
VectorStoreFileDeleteResponse,
|
||||||
|
VectorStoreFileObject,
|
||||||
|
VectorStoreFileStatus,
|
||||||
|
VectorStoreObject,
|
||||||
|
VectorStoreSearchResponsePage,
|
||||||
|
)
|
||||||
from llama_stack.distribution.datatypes import (
|
from llama_stack.distribution.datatypes import (
|
||||||
VectorDBWithOwner,
|
VectorDBWithOwner,
|
||||||
)
|
)
|
||||||
|
@ -74,3 +87,135 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||||
if existing_vector_db is None:
|
if existing_vector_db is None:
|
||||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||||
await self.unregister_object(existing_vector_db)
|
await self.unregister_object(existing_vector_db)
|
||||||
|
|
||||||
|
async def openai_retrieve_vector_store(
|
||||||
|
self,
|
||||||
|
vector_store_id: str,
|
||||||
|
) -> VectorStoreObject:
|
||||||
|
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||||
|
return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store(vector_store_id)
|
||||||
|
|
||||||
|
async def openai_update_vector_store(
|
||||||
|
self,
|
||||||
|
vector_store_id: str,
|
||||||
|
name: str | None = None,
|
||||||
|
expires_after: dict[str, Any] | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> VectorStoreObject:
|
||||||
|
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||||
|
return await self.get_provider_impl(vector_store_id).openai_update_vector_store(
|
||||||
|
vector_store_id=vector_store_id,
|
||||||
|
name=name,
|
||||||
|
expires_after=expires_after,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def openai_delete_vector_store(
|
||||||
|
self,
|
||||||
|
vector_store_id: str,
|
||||||
|
) -> VectorStoreDeleteResponse:
|
||||||
|
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
|
||||||
|
result = await self.get_provider_impl(vector_store_id).openai_delete_vector_store(vector_store_id)
|
||||||
|
await self.unregister_vector_db(vector_store_id)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def openai_search_vector_store(
|
||||||
|
self,
|
||||||
|
vector_store_id: str,
|
||||||
|
query: str | list[str],
|
||||||
|
filters: dict[str, Any] | None = None,
|
||||||
|
max_num_results: int | None = 10,
|
||||||
|
ranking_options: SearchRankingOptions | None = None,
|
||||||
|
rewrite_query: bool | None = False,
|
||||||
|
search_mode: str | None = "vector",
|
||||||
|
) -> VectorStoreSearchResponsePage:
|
||||||
|
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||||
|
return await self.get_provider_impl(vector_store_id).openai_search_vector_store(
|
||||||
|
vector_store_id=vector_store_id,
|
||||||
|
query=query,
|
||||||
|
filters=filters,
|
||||||
|
max_num_results=max_num_results,
|
||||||
|
ranking_options=ranking_options,
|
||||||
|
rewrite_query=rewrite_query,
|
||||||
|
search_mode=search_mode,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def openai_attach_file_to_vector_store(
|
||||||
|
self,
|
||||||
|
vector_store_id: str,
|
||||||
|
file_id: str,
|
||||||
|
attributes: dict[str, Any] | None = None,
|
||||||
|
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||||
|
) -> VectorStoreFileObject:
|
||||||
|
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||||
|
return await self.get_provider_impl(vector_store_id).openai_attach_file_to_vector_store(
|
||||||
|
vector_store_id=vector_store_id,
|
||||||
|
file_id=file_id,
|
||||||
|
attributes=attributes,
|
||||||
|
chunking_strategy=chunking_strategy,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def openai_list_files_in_vector_store(
|
||||||
|
self,
|
||||||
|
vector_store_id: str,
|
||||||
|
limit: int | None = 20,
|
||||||
|
order: str | None = "desc",
|
||||||
|
after: str | None = None,
|
||||||
|
before: str | None = None,
|
||||||
|
filter: VectorStoreFileStatus | None = None,
|
||||||
|
) -> list[VectorStoreFileObject]:
|
||||||
|
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||||
|
return await self.get_provider_impl(vector_store_id).openai_list_files_in_vector_store(
|
||||||
|
vector_store_id=vector_store_id,
|
||||||
|
limit=limit,
|
||||||
|
order=order,
|
||||||
|
after=after,
|
||||||
|
before=before,
|
||||||
|
filter=filter,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def openai_retrieve_vector_store_file(
|
||||||
|
self,
|
||||||
|
vector_store_id: str,
|
||||||
|
file_id: str,
|
||||||
|
) -> VectorStoreFileObject:
|
||||||
|
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||||
|
return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file(
|
||||||
|
vector_store_id=vector_store_id,
|
||||||
|
file_id=file_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def openai_retrieve_vector_store_file_contents(
|
||||||
|
self,
|
||||||
|
vector_store_id: str,
|
||||||
|
file_id: str,
|
||||||
|
) -> VectorStoreFileContentsResponse:
|
||||||
|
await self.assert_action_allowed("read", "vector_db", vector_store_id)
|
||||||
|
return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file_contents(
|
||||||
|
vector_store_id=vector_store_id,
|
||||||
|
file_id=file_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def openai_update_vector_store_file(
|
||||||
|
self,
|
||||||
|
vector_store_id: str,
|
||||||
|
file_id: str,
|
||||||
|
attributes: dict[str, Any],
|
||||||
|
) -> VectorStoreFileObject:
|
||||||
|
await self.assert_action_allowed("update", "vector_db", vector_store_id)
|
||||||
|
return await self.get_provider_impl(vector_store_id).openai_update_vector_store_file(
|
||||||
|
vector_store_id=vector_store_id,
|
||||||
|
file_id=file_id,
|
||||||
|
attributes=attributes,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def openai_delete_vector_store_file(
|
||||||
|
self,
|
||||||
|
vector_store_id: str,
|
||||||
|
file_id: str,
|
||||||
|
) -> VectorStoreFileDeleteResponse:
|
||||||
|
await self.assert_action_allowed("delete", "vector_db", vector_store_id)
|
||||||
|
return await self.get_provider_impl(vector_store_id).openai_delete_vector_store_file(
|
||||||
|
vector_store_id=vector_store_id,
|
||||||
|
file_id=file_id,
|
||||||
|
)
|
||||||
|
|
|
@ -11,17 +11,15 @@ from unittest.mock import AsyncMock
|
||||||
from llama_stack.apis.common.type_system import NumberType
|
from llama_stack.apis.common.type_system import NumberType
|
||||||
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
|
||||||
from llama_stack.apis.datatypes import Api
|
from llama_stack.apis.datatypes import Api
|
||||||
from llama_stack.apis.models import Model, ModelType
|
from llama_stack.apis.models import Model
|
||||||
from llama_stack.apis.shields.shields import Shield
|
from llama_stack.apis.shields.shields import Shield
|
||||||
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
|
from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter
|
||||||
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
|
|
||||||
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
|
from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable
|
||||||
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
|
from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable
|
||||||
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
|
from llama_stack.distribution.routing_tables.models import ModelsRoutingTable
|
||||||
from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable
|
from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable
|
||||||
from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable
|
from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable
|
||||||
from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable
|
from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable
|
||||||
from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable
|
|
||||||
|
|
||||||
|
|
||||||
class Impl:
|
class Impl:
|
||||||
|
@ -54,17 +52,6 @@ class SafetyImpl(Impl):
|
||||||
return shield
|
return shield
|
||||||
|
|
||||||
|
|
||||||
class VectorDBImpl(Impl):
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(Api.vector_io)
|
|
||||||
|
|
||||||
async def register_vector_db(self, vector_db: VectorDB):
|
|
||||||
return vector_db
|
|
||||||
|
|
||||||
async def unregister_vector_db(self, vector_db_id: str):
|
|
||||||
return vector_db_id
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetsImpl(Impl):
|
class DatasetsImpl(Impl):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(Api.datasetio)
|
super().__init__(Api.datasetio)
|
||||||
|
@ -173,36 +160,6 @@ async def test_shields_routing_table(cached_disk_dist_registry):
|
||||||
assert "test-shield-2" in shield_ids
|
assert "test-shield-2" in shield_ids
|
||||||
|
|
||||||
|
|
||||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
|
||||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
|
||||||
await table.initialize()
|
|
||||||
|
|
||||||
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
|
||||||
await m_table.initialize()
|
|
||||||
await m_table.register_model(
|
|
||||||
model_id="test-model",
|
|
||||||
provider_id="test_provider",
|
|
||||||
metadata={"embedding_dimension": 128},
|
|
||||||
model_type=ModelType.embedding,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Register multiple vector databases and verify listing
|
|
||||||
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
|
|
||||||
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
|
|
||||||
vector_dbs = await table.list_vector_dbs()
|
|
||||||
|
|
||||||
assert len(vector_dbs.data) == 2
|
|
||||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
|
||||||
assert "test-vectordb" in vector_db_ids
|
|
||||||
assert "test-vectordb-2" in vector_db_ids
|
|
||||||
|
|
||||||
await table.unregister_vector_db(vector_db_id="test-vectordb")
|
|
||||||
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
|
|
||||||
|
|
||||||
vector_dbs = await table.list_vector_dbs()
|
|
||||||
assert len(vector_dbs.data) == 0
|
|
||||||
|
|
||||||
|
|
||||||
async def test_datasets_routing_table(cached_disk_dist_registry):
|
async def test_datasets_routing_table(cached_disk_dist_registry):
|
||||||
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
|
table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {})
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
5
tests/unit/distribution/routing_tables/__init__.py
Normal file
5
tests/unit/distribution/routing_tables/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
274
tests/unit/distribution/routing_tables/test_vector_dbs.py
Normal file
274
tests/unit/distribution/routing_tables/test_vector_dbs.py
Normal file
|
@ -0,0 +1,274 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Unit tests for the routing tables vector_dbs
|
||||||
|
|
||||||
|
import time
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.datatypes import Api
|
||||||
|
from llama_stack.apis.models import ModelType
|
||||||
|
from llama_stack.apis.vector_dbs.vector_dbs import VectorDB
|
||||||
|
from llama_stack.apis.vector_io.vector_io import (
|
||||||
|
VectorStoreContent,
|
||||||
|
VectorStoreDeleteResponse,
|
||||||
|
VectorStoreFileContentsResponse,
|
||||||
|
VectorStoreFileCounts,
|
||||||
|
VectorStoreFileDeleteResponse,
|
||||||
|
VectorStoreFileObject,
|
||||||
|
VectorStoreObject,
|
||||||
|
VectorStoreSearchResponsePage,
|
||||||
|
)
|
||||||
|
from llama_stack.distribution.access_control.datatypes import AccessRule, Scope
|
||||||
|
from llama_stack.distribution.datatypes import User
|
||||||
|
from llama_stack.distribution.request_headers import request_provider_data_context
|
||||||
|
from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable
|
||||||
|
from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceImpl, ModelsRoutingTable
|
||||||
|
|
||||||
|
|
||||||
|
class VectorDBImpl(Impl):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__(Api.vector_io)
|
||||||
|
|
||||||
|
async def register_vector_db(self, vector_db: VectorDB):
|
||||||
|
return vector_db
|
||||||
|
|
||||||
|
async def unregister_vector_db(self, vector_db_id: str):
|
||||||
|
return vector_db_id
|
||||||
|
|
||||||
|
async def openai_retrieve_vector_store(self, vector_store_id):
|
||||||
|
return VectorStoreObject(
|
||||||
|
id=vector_store_id,
|
||||||
|
name="Test Store",
|
||||||
|
created_at=int(time.time()),
|
||||||
|
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def openai_update_vector_store(self, vector_store_id, **kwargs):
|
||||||
|
return VectorStoreObject(
|
||||||
|
id=vector_store_id,
|
||||||
|
name="Updated Store",
|
||||||
|
created_at=int(time.time()),
|
||||||
|
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def openai_delete_vector_store(self, vector_store_id):
|
||||||
|
return VectorStoreDeleteResponse(id=vector_store_id, object="vector_store.deleted", deleted=True)
|
||||||
|
|
||||||
|
async def openai_search_vector_store(self, vector_store_id, query, **kwargs):
|
||||||
|
return VectorStoreSearchResponsePage(
|
||||||
|
object="vector_store.search_results.page", search_query="query", data=[], has_more=False, next_page=None
|
||||||
|
)
|
||||||
|
|
||||||
|
async def openai_attach_file_to_vector_store(self, vector_store_id, file_id, **kwargs):
|
||||||
|
return VectorStoreFileObject(
|
||||||
|
id=file_id,
|
||||||
|
status="completed",
|
||||||
|
chunking_strategy={"type": "auto"},
|
||||||
|
created_at=int(time.time()),
|
||||||
|
vector_store_id=vector_store_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def openai_list_files_in_vector_store(self, vector_store_id, **kwargs):
|
||||||
|
return [
|
||||||
|
VectorStoreFileObject(
|
||||||
|
id="1",
|
||||||
|
status="completed",
|
||||||
|
chunking_strategy={"type": "auto"},
|
||||||
|
created_at=int(time.time()),
|
||||||
|
vector_store_id=vector_store_id,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
|
||||||
|
async def openai_retrieve_vector_store_file(self, vector_store_id, file_id):
|
||||||
|
return VectorStoreFileObject(
|
||||||
|
id=file_id,
|
||||||
|
status="completed",
|
||||||
|
chunking_strategy={"type": "auto"},
|
||||||
|
created_at=int(time.time()),
|
||||||
|
vector_store_id=vector_store_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def openai_retrieve_vector_store_file_contents(self, vector_store_id, file_id):
|
||||||
|
return VectorStoreFileContentsResponse(
|
||||||
|
file_id=file_id,
|
||||||
|
filename="Sample File name",
|
||||||
|
attributes={"key": "value"},
|
||||||
|
content=[VectorStoreContent(type="text", text="Sample content")],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def openai_update_vector_store_file(self, vector_store_id, file_id, **kwargs):
|
||||||
|
return VectorStoreFileObject(
|
||||||
|
id=file_id,
|
||||||
|
status="completed",
|
||||||
|
chunking_strategy={"type": "auto"},
|
||||||
|
created_at=int(time.time()),
|
||||||
|
vector_store_id=vector_store_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def openai_delete_vector_store_file(self, vector_store_id, file_id):
|
||||||
|
return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||||
|
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
||||||
|
await table.initialize()
|
||||||
|
|
||||||
|
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||||
|
await m_table.initialize()
|
||||||
|
await m_table.register_model(
|
||||||
|
model_id="test-model",
|
||||||
|
provider_id="test_provider",
|
||||||
|
metadata={"embedding_dimension": 128},
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register multiple vector databases and verify listing
|
||||||
|
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
|
||||||
|
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
|
||||||
|
vector_dbs = await table.list_vector_dbs()
|
||||||
|
|
||||||
|
assert len(vector_dbs.data) == 2
|
||||||
|
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||||
|
assert "test-vectordb" in vector_db_ids
|
||||||
|
assert "test-vectordb-2" in vector_db_ids
|
||||||
|
|
||||||
|
await table.unregister_vector_db(vector_db_id="test-vectordb")
|
||||||
|
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
|
||||||
|
|
||||||
|
vector_dbs = await table.list_vector_dbs()
|
||||||
|
assert len(vector_dbs.data) == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
|
||||||
|
impl = VectorDBImpl()
|
||||||
|
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
|
||||||
|
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=[])
|
||||||
|
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])
|
||||||
|
authorized_table = "vs1"
|
||||||
|
authorized_team = "team1"
|
||||||
|
unauthorized_team = "team2"
|
||||||
|
|
||||||
|
await m_table.initialize()
|
||||||
|
await m_table.register_model(
|
||||||
|
model_id="test-model",
|
||||||
|
provider_id="test_provider",
|
||||||
|
metadata={"embedding_dimension": 128},
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
|
||||||
|
with request_provider_data_context({}, authorized_user):
|
||||||
|
_ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
|
||||||
|
|
||||||
|
# Authorized reader
|
||||||
|
with request_provider_data_context({}, authorized_user):
|
||||||
|
res = await table.openai_retrieve_vector_store(authorized_table)
|
||||||
|
assert res == "OK"
|
||||||
|
|
||||||
|
# Authorized updater
|
||||||
|
impl.openai_update_vector_store_file = AsyncMock(return_value="UPDATED")
|
||||||
|
with request_provider_data_context({}, authorized_user):
|
||||||
|
res = await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
|
||||||
|
assert res == "UPDATED"
|
||||||
|
|
||||||
|
# Unauthorized reader
|
||||||
|
unauthorized_user = User(principal="eve", attributes={"roles": [unauthorized_team]})
|
||||||
|
with request_provider_data_context({}, unauthorized_user):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await table.openai_retrieve_vector_store(authorized_table)
|
||||||
|
|
||||||
|
# Unauthorized updater
|
||||||
|
with request_provider_data_context({}, unauthorized_user):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"})
|
||||||
|
|
||||||
|
# Authorized deleter
|
||||||
|
impl.openai_delete_vector_store_file = AsyncMock(return_value="DELETED")
|
||||||
|
with request_provider_data_context({}, authorized_user):
|
||||||
|
res = await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
|
||||||
|
assert res == "DELETED"
|
||||||
|
|
||||||
|
# Unauthorized deleter
|
||||||
|
with request_provider_data_context({}, unauthorized_user):
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
await table.openai_delete_vector_store_file(authorized_table, file_id="file1")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_registry):
    """Verify action-level access control on OpenAI vector-store routing-table methods.

    Policy under test:
      * role ``admin``  -> may create / read / update / delete
      * role ``reader`` -> may only read
      * any other role  -> no access at all
    """
    impl = VectorDBImpl()

    policy = [
        AccessRule(permit=Scope(actions=["create", "read", "update", "delete"]), when="user with admin in roles"),
        AccessRule(permit=Scope(actions=["read"]), when="user with reader in roles"),
    ]

    table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=policy)
    m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[])

    vector_db_id = "vs1"
    file_id = "file-1"

    admin_user = User(principal="admin", attributes={"roles": ["admin"]})
    read_only_user = User(principal="reader", attributes={"roles": ["reader"]})
    no_access_user = User(principal="outsider", attributes={"roles": ["no_access"]})

    # An embedding model must exist before a vector DB can be registered against it.
    await m_table.initialize()
    await m_table.register_model(
        model_id="test-model",
        provider_id="test_provider",
        metadata={"embedding_dimension": 128},
        model_type=ModelType.embedding,
    )

    # Creation requires the "create" action, which only the admin holds.
    with request_provider_data_context({}, admin_user):
        await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")

    # (method, positional args, keyword args) triples, grouped by required action.
    read_methods = [
        (table.openai_retrieve_vector_store, (vector_db_id,), {}),
        (table.openai_search_vector_store, (vector_db_id, "query"), {}),
        (table.openai_list_files_in_vector_store, (vector_db_id,), {}),
        (table.openai_retrieve_vector_store_file, (vector_db_id, file_id), {}),
        (table.openai_retrieve_vector_store_file_contents, (vector_db_id, file_id), {}),
    ]
    update_methods = [
        (table.openai_update_vector_store, (vector_db_id,), {"name": "Updated DB"}),
        (table.openai_attach_file_to_vector_store, (vector_db_id, file_id), {}),
        (table.openai_update_vector_store_file, (vector_db_id, file_id), {"attributes": {"key": "value"}}),
    ]
    delete_methods = [
        (table.openai_delete_vector_store_file, (vector_db_id, file_id), {}),
        (table.openai_delete_vector_store, (vector_db_id,), {}),
    ]

    # Both admin and reader hold "read": every read operation must succeed.
    for user in [admin_user, read_only_user]:
        with request_provider_data_context({}, user):
            for method, args, kwargs in read_methods:
                result = await method(*args, **kwargs)
                assert result is not None, f"Read operation failed with user {user.principal}"

    # A user matching no policy rule is denied even read access.
    with request_provider_data_context({}, no_access_user):
        for method, args, kwargs in read_methods:
            with pytest.raises(ValueError):
                await method(*args, **kwargs)

    # Fix: previously reader/no-access users were never checked against update
    # operations. Verify denial here, while the store still exists, so the
    # failure is unambiguously an ACL rejection.
    for user in [read_only_user, no_access_user]:
        with request_provider_data_context({}, user):
            for method, args, kwargs in update_methods:
                with pytest.raises(ValueError):
                    await method(*args, **kwargs)

    # Admin holds "update": all update operations succeed.
    with request_provider_data_context({}, admin_user):
        for method, args, kwargs in update_methods:
            result = await method(*args, **kwargs)
            assert result is not None, "Update operation failed with admin user"

    # Admin holds "delete": file delete, then store delete, both succeed.
    with request_provider_data_context({}, admin_user):
        for method, args, kwargs in delete_methods:
            result = await method(*args, **kwargs)
            assert result is not None, "Delete operation failed with admin user"

    # Non-admin users cannot delete. NOTE(review): the store was already
    # deleted above, so the ValueError here may stem from a not-found lookup
    # rather than an ACL denial — the denied-update loop above provides the
    # unambiguous negative ACL coverage.
    for user in [read_only_user, no_access_user]:
        with request_provider_data_context({}, user):
            for method, args, kwargs in delete_methods:
                with pytest.raises(ValueError):
                    await method(*args, **kwargs)
|
Loading…
Add table
Add a link
Reference in a new issue