mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-23 04:53:14 +00:00
# What does this PR do?

Refactors the vector store routing logic by moving OpenAI-compatible vector store operations from the `VectorIORouter` to the `VectorDBsRoutingTable`.

Closes https://github.com/meta-llama/llama-stack/issues/2761

## Test Plan

Added unit tests to cover new routing logic and ACL checks.

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
221 lines
8.6 KiB
Python
221 lines
8.6 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from typing import Any
|
|
|
|
from pydantic import TypeAdapter
|
|
|
|
from llama_stack.apis.models import ModelType
|
|
from llama_stack.apis.resource import ResourceType
|
|
from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
|
|
from llama_stack.apis.vector_io.vector_io import (
|
|
SearchRankingOptions,
|
|
VectorStoreChunkingStrategy,
|
|
VectorStoreDeleteResponse,
|
|
VectorStoreFileContentsResponse,
|
|
VectorStoreFileDeleteResponse,
|
|
VectorStoreFileObject,
|
|
VectorStoreFileStatus,
|
|
VectorStoreObject,
|
|
VectorStoreSearchResponsePage,
|
|
)
|
|
from llama_stack.distribution.datatypes import (
|
|
VectorDBWithOwner,
|
|
)
|
|
from llama_stack.log import get_logger
|
|
|
|
from .common import CommonRoutingTableImpl
|
|
|
|
# Module-level logger for this routing table; records are tagged with the "core" category.
logger = get_logger(name=__name__, category="core")
|
|
|
|
|
|
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
    """Routing table for vector DB resources.

    Besides registering, listing and unregistering vector DBs, this table
    forwards the OpenAI-compatible vector store operations to the provider
    implementation that owns the store. Every forwarded operation first runs
    an access check via ``assert_action_allowed``.
    """

    async def list_vector_dbs(self) -> ListVectorDBsResponse:
        """Return every registered object of type ``vector_db``."""
        vector_dbs = await self.get_all_with_type("vector_db")
        return ListVectorDBsResponse(data=vector_dbs)

    async def get_vector_db(self, vector_db_id: str) -> VectorDB:
        """Look up a vector DB by identifier.

        Raises:
            ValueError: If no vector DB with ``vector_db_id`` is found.
        """
        vector_db = await self.get_object_by_identifier("vector_db", vector_db_id)
        if vector_db is None:
            raise ValueError(f"Vector DB '{vector_db_id}' not found")
        return vector_db

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
        vector_db_name: str | None = None,
    ) -> VectorDB:
        """Validate and register a new vector DB.

        Args:
            vector_db_id: Identifier to register the vector DB under.
            embedding_model: Identifier of a registered embedding model.
            embedding_dimension: Accepted for interface compatibility; this
                method does not read it. The stored dimension is taken from
                the embedding model's metadata.
            provider_id: Provider to register with. When omitted, the first
                entry of ``impls_by_provider_id`` is used.
            provider_vector_db_id: Provider-side identifier; defaults to
                ``vector_db_id``.
            vector_db_name: Optional name stored with the vector DB.

        Returns:
            The registered vector DB object.

        Raises:
            ValueError: If no provider is configured, or if the model is
                missing, is not an embedding model, or has no
                ``embedding_dimension`` entry in its metadata.
        """
        if provider_vector_db_id is None:
            provider_vector_db_id = vector_db_id

        if provider_id is None:
            if not self.impls_by_provider_id:
                raise ValueError("No provider available. Please configure a vector_io provider.")
            # Fall back to the first configured provider.
            provider_id = next(iter(self.impls_by_provider_id))
            if len(self.impls_by_provider_id) > 1:
                logger.warning(
                    f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
                )

        model = await self.get_object_by_identifier("model", embedding_model)
        if model is None:
            raise ValueError(f"Model {embedding_model} not found")
        if model.model_type != ModelType.embedding:
            raise ValueError(f"Model {embedding_model} is not an embedding model")
        if "embedding_dimension" not in model.metadata:
            raise ValueError(f"Model {embedding_model} does not have an embedding dimension")

        vector_db_data = {
            "identifier": vector_db_id,
            "type": ResourceType.vector_db.value,
            "provider_id": provider_id,
            "provider_resource_id": provider_vector_db_id,
            "embedding_model": embedding_model,
            # The model's metadata is the source of the dimension, not the argument.
            "embedding_dimension": model.metadata["embedding_dimension"],
            "vector_db_name": vector_db_name,
        }
        vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
        await self.register_object(vector_db)
        return vector_db

    async def unregister_vector_db(self, vector_db_id: str) -> None:
        """Unregister the vector DB identified by ``vector_db_id``.

        Raises:
            ValueError: If the vector DB is not found (raised by
                ``get_vector_db``).
        """
        # get_vector_db raises for a missing DB, so no extra None check is needed.
        existing_vector_db = await self.get_vector_db(vector_db_id)
        await self.unregister_object(existing_vector_db)

    async def openai_retrieve_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreObject:
        """Run the "read" access check, then retrieve the store from its provider."""
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        provider = self.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store(vector_store_id)

    async def openai_update_vector_store(
        self,
        vector_store_id: str,
        name: str | None = None,
        expires_after: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> VectorStoreObject:
        """Run the "update" access check, then forward the update to the provider."""
        await self.assert_action_allowed("update", "vector_db", vector_store_id)
        provider = self.get_provider_impl(vector_store_id)
        return await provider.openai_update_vector_store(
            vector_store_id=vector_store_id,
            name=name,
            expires_after=expires_after,
            metadata=metadata,
        )

    async def openai_delete_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreDeleteResponse:
        """Run the "delete" access check, delete the store, then unregister it.

        The provider-side deletion happens first; the routing-table entry is
        removed afterwards and the provider's response is returned.
        """
        await self.assert_action_allowed("delete", "vector_db", vector_store_id)
        provider = self.get_provider_impl(vector_store_id)
        result = await provider.openai_delete_vector_store(vector_store_id)
        await self.unregister_vector_db(vector_store_id)
        return result

    async def openai_search_vector_store(
        self,
        vector_store_id: str,
        query: str | list[str],
        filters: dict[str, Any] | None = None,
        max_num_results: int | None = 10,
        ranking_options: SearchRankingOptions | None = None,
        rewrite_query: bool | None = False,
        search_mode: str | None = "vector",
    ) -> VectorStoreSearchResponsePage:
        """Run the "read" access check, then forward the search to the provider."""
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        provider = self.get_provider_impl(vector_store_id)
        return await provider.openai_search_vector_store(
            vector_store_id=vector_store_id,
            query=query,
            filters=filters,
            max_num_results=max_num_results,
            ranking_options=ranking_options,
            rewrite_query=rewrite_query,
            search_mode=search_mode,
        )

    async def openai_attach_file_to_vector_store(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any] | None = None,
        chunking_strategy: VectorStoreChunkingStrategy | None = None,
    ) -> VectorStoreFileObject:
        """Run the "update" access check, then attach the file via the provider."""
        await self.assert_action_allowed("update", "vector_db", vector_store_id)
        provider = self.get_provider_impl(vector_store_id)
        return await provider.openai_attach_file_to_vector_store(
            vector_store_id=vector_store_id,
            file_id=file_id,
            attributes=attributes,
            chunking_strategy=chunking_strategy,
        )

    async def openai_list_files_in_vector_store(
        self,
        vector_store_id: str,
        limit: int | None = 20,
        order: str | None = "desc",
        after: str | None = None,
        before: str | None = None,
        filter: VectorStoreFileStatus | None = None,
    ) -> list[VectorStoreFileObject]:
        """Run the "read" access check, then list the store's files via the provider.

        All paging and filtering arguments are passed through unchanged.
        """
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        provider = self.get_provider_impl(vector_store_id)
        return await provider.openai_list_files_in_vector_store(
            vector_store_id=vector_store_id,
            limit=limit,
            order=order,
            after=after,
            before=before,
            filter=filter,
        )

    async def openai_retrieve_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileObject:
        """Run the "read" access check, then retrieve one file via the provider."""
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        provider = self.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )

    async def openai_retrieve_vector_store_file_contents(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileContentsResponse:
        """Run the "read" access check, then fetch a file's contents via the provider."""
        await self.assert_action_allowed("read", "vector_db", vector_store_id)
        provider = self.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store_file_contents(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )

    async def openai_update_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any],
    ) -> VectorStoreFileObject:
        """Run the "update" access check, then update a file's attributes via the provider."""
        await self.assert_action_allowed("update", "vector_db", vector_store_id)
        provider = self.get_provider_impl(vector_store_id)
        return await provider.openai_update_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
            attributes=attributes,
        )

    async def openai_delete_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileDeleteResponse:
        """Run the "delete" access check, then delete the file via the provider."""
        await self.assert_action_allowed("delete", "vector_db", vector_store_id)
        provider = self.get_provider_impl(vector_store_id)
        return await provider.openai_delete_vector_store_file(
            vector_store_id=vector_store_id,
            file_id=file_id,
        )
|