mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
When listing vector stores, if one of them cannot be retrieved, log an error and return all the valid ones. ### Test Plan ``` pytest -sv --stack-config=http://localhost:8321 tests/integration/vector_io/test_openai_vector_stores.py --embedding-model all-MiniLM-L6-v2 ``` Also tested with `--stack-config fireworks`.
278 lines
11 KiB
Python
278 lines
11 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from typing import Any
|
|
|
|
from llama_stack.apis.common.content_types import (
|
|
InterleavedContent,
|
|
)
|
|
from llama_stack.apis.models import ModelType
|
|
from llama_stack.apis.vector_io import (
|
|
Chunk,
|
|
QueryChunksResponse,
|
|
VectorIO,
|
|
VectorStoreDeleteResponse,
|
|
VectorStoreListResponse,
|
|
VectorStoreObject,
|
|
VectorStoreSearchResponsePage,
|
|
)
|
|
from llama_stack.apis.vector_io.vector_io import VectorStoreChunkingStrategy, VectorStoreFileObject
|
|
from llama_stack.log import get_logger
|
|
from llama_stack.providers.datatypes import RoutingTable
|
|
|
|
# Module-level logger for the vector IO router, filed under the "core" category.
logger = get_logger(category="core", name=__name__)
|
|
|
|
|
|
class VectorIORouter(VectorIO):
    """Routes to a provider based on the vector db identifier.

    Single-store operations are dispatched to the provider implementation that
    ``routing_table`` resolves for the given vector db / vector store id.
    Listing aggregates over every vector db registered in the routing table.
    """

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        """Keep the routing table used to resolve provider implementations."""
        logger.debug("Initializing VectorIORouter")
        self.routing_table = routing_table

    async def initialize(self) -> None:
        """No-op lifecycle hook; only emits a debug log line."""
        logger.debug("VectorIORouter.initialize")
        pass

    async def shutdown(self) -> None:
        """No-op lifecycle hook; only emits a debug log line."""
        logger.debug("VectorIORouter.shutdown")
        pass

    async def _get_first_embedding_model(self) -> tuple[str, int] | None:
        """Get the first available embedding model identifier and its dimension.

        Returns:
            ``(identifier, embedding_dimension)`` of the first model whose
            ``model_type`` is ``ModelType.embedding``, or ``None`` when no such
            model is registered or the model lookup itself fails (logged).

        Raises:
            ValueError: if that first embedding model has no
                ``embedding_dimension`` entry in its metadata.
        """
        try:
            # Get all models from the routing table
            all_models = await self.routing_table.get_all_with_type("model")
        except Exception as e:
            logger.error(f"Error getting embedding models: {e}")
            return None

        # Filter for embedding models
        embedding_models = [
            model
            for model in all_models
            if hasattr(model, "model_type") and model.model_type == ModelType.embedding
        ]
        if not embedding_models:
            logger.warning("No embedding models found in the routing table")
            return None

        first_model = embedding_models[0]
        dimension = first_model.metadata.get("embedding_dimension", None)
        # Raised outside the try above on purpose: a model without a dimension is a
        # configuration error, and it must not be masked as "no embedding models".
        if dimension is None:
            raise ValueError(f"Embedding model {first_model.identifier} has no embedding dimension")
        return first_model.identifier, dimension

    async def register_vector_db(
        self,
        vector_db_id: str,
        embedding_model: str,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> None:
        """Register a vector db by forwarding all arguments to the routing table."""
        logger.debug(f"VectorIORouter.register_vector_db: {vector_db_id}, {embedding_model}")
        await self.routing_table.register_vector_db(
            vector_db_id,
            embedding_model,
            embedding_dimension,
            provider_id,
            provider_vector_db_id,
        )

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        """Insert ``chunks`` into ``vector_db_id`` through the provider that owns it.

        Args:
            vector_db_id: identifier used to resolve the provider.
            chunks: chunks to insert; only the first three are previewed in the log.
            ttl_seconds: forwarded unchanged to the provider.
        """
        # .get(): a chunk's metadata is not guaranteed to carry "document_id", and a
        # debug log line (evaluated eagerly as an f-string) must never make the
        # insert itself fail with KeyError.
        preview_ids = [chunk.metadata.get("document_id") for chunk in chunks[:3]]
        logger.debug(
            f"VectorIORouter.insert_chunks: {vector_db_id}, {len(chunks)} chunks, ttl_seconds={ttl_seconds}, "
            f"chunk_ids={preview_ids}{' and more...' if len(chunks) > 3 else ''}",
        )
        return await self.routing_table.get_provider_impl(vector_db_id).insert_chunks(vector_db_id, chunks, ttl_seconds)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        """Query ``vector_db_id`` through the provider that owns it."""
        logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
        return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)

    # OpenAI Vector Stores API endpoints
    async def openai_create_vector_store(
        self,
        name: str,
        file_ids: list[str] | None = None,
        expires_after: dict[str, Any] | None = None,
        chunking_strategy: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        embedding_model: str | None = None,
        embedding_dimension: int | None = None,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorStoreObject:
        """Register a vector db named ``name`` and create the vector store on its provider.

        When ``embedding_model`` is omitted, the first registered embedding model
        and its dimension are used instead.

        Raises:
            ValueError: if no embedding model is given and none can be found, or
                if the fallback embedding model has no embedding dimension.
        """
        logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")

        # If no embedding model is provided, use the first available one
        if embedding_model is None:
            embedding_model_info = await self._get_first_embedding_model()
            if embedding_model_info is None:
                raise ValueError("No embedding model provided and no embedding models available in the system")
            embedding_model, embedding_dimension = embedding_model_info
            logger.info(f"No embedding model specified, using first available: {embedding_model}")

        # The store's name doubles as its vector db id.
        vector_db_id = name
        registered_vector_db = await self.routing_table.register_vector_db(
            vector_db_id,
            embedding_model,
            embedding_dimension,
            provider_id,
            provider_vector_db_id,
        )

        # Use the provider/resource ids the routing table actually assigned.
        return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store(
            vector_db_id,
            file_ids=file_ids,
            expires_after=expires_after,
            chunking_strategy=chunking_strategy,
            metadata=metadata,
            embedding_model=embedding_model,
            embedding_dimension=embedding_dimension,
            provider_id=registered_vector_db.provider_id,
            provider_vector_db_id=registered_vector_db.provider_resource_id,
        )

    async def openai_list_vector_stores(
        self,
        limit: int | None = 20,
        order: str | None = "desc",
        after: str | None = None,
        before: str | None = None,
    ) -> VectorStoreListResponse:
        """List vector stores from every registered vector db, with cursor pagination.

        Each registered vector db is retrieved from the provider that owns it. A
        store that fails to load is logged and skipped, so one broken store does
        not hide the valid ones.

        Args:
            limit: maximum number of stores in the page; ``None`` means the
                default of 20.
            order: ``"desc"`` (the default, also used for ``None``) sorts by
                ``created_at`` newest first; any other value sorts oldest first.
            after: return only stores positioned after the store with this id;
                ignored if no store has this id.
            before: return only stores positioned before the store with this id.

        Returns:
            The page of stores plus ``has_more``, ``first_id`` and ``last_id``.
        """
        # Normalize explicit None to the declared defaults; comparing len() > None
        # below would otherwise raise TypeError.
        if limit is None:
            limit = 20
        if order is None:
            order = "desc"
        logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")

        # Aggregate across all registered vector dbs: retrieve each store from the
        # provider that owns it, skipping (and logging) any that cannot be loaded.
        vector_dbs = await self.routing_table.get_all_with_type("vector_db")
        all_stores = []
        for vector_db in vector_dbs:
            try:
                vector_store = await self.routing_table.get_provider_impl(
                    vector_db.identifier
                ).openai_retrieve_vector_store(vector_db.identifier)
                all_stores.append(vector_store)
            except Exception as e:
                logger.error(f"Error retrieving vector store {vector_db.identifier}: {e}")
                continue

        # Sort by created_at
        reverse_order = order == "desc"
        all_stores.sort(key=lambda x: x.created_at, reverse=reverse_order)

        # Apply cursor-based pagination
        if after:
            after_index = next((i for i, store in enumerate(all_stores) if store.id == after), -1)
            if after_index >= 0:
                all_stores = all_stores[after_index + 1 :]

        if before:
            before_index = next((i for i, store in enumerate(all_stores) if store.id == before), len(all_stores))
            all_stores = all_stores[:before_index]

        # Apply limit
        limited_stores = all_stores[:limit]

        # Determine pagination info
        has_more = len(all_stores) > limit
        first_id = limited_stores[0].id if limited_stores else None
        last_id = limited_stores[-1].id if limited_stores else None

        return VectorStoreListResponse(
            data=limited_stores,
            has_more=has_more,
            first_id=first_id,
            last_id=last_id,
        )

    async def openai_retrieve_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreObject:
        """Retrieve a vector store from the provider that owns ``vector_store_id``."""
        logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
        # Route based on vector store ID
        provider = self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_retrieve_vector_store(vector_store_id)

    async def openai_update_vector_store(
        self,
        vector_store_id: str,
        name: str | None = None,
        expires_after: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> VectorStoreObject:
        """Forward a vector store update to the provider that owns ``vector_store_id``."""
        logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
        # Route based on vector store ID
        provider = self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_update_vector_store(
            vector_store_id=vector_store_id,
            name=name,
            expires_after=expires_after,
            metadata=metadata,
        )

    async def openai_delete_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreDeleteResponse:
        """Delete the store on its provider, then unregister its vector db.

        Returns the provider's delete response.
        """
        logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
        # Route based on vector store ID
        provider = self.routing_table.get_provider_impl(vector_store_id)
        result = await provider.openai_delete_vector_store(vector_store_id)
        # Drop from the registry only after the provider-side delete returned, so
        # the store stays routable if that delete raises.
        await self.routing_table.unregister_vector_db(vector_store_id)
        return result

    async def openai_search_vector_store(
        self,
        vector_store_id: str,
        query: str | list[str],
        filters: dict[str, Any] | None = None,
        max_num_results: int | None = 10,
        ranking_options: dict[str, Any] | None = None,
        rewrite_query: bool | None = False,
    ) -> VectorStoreSearchResponsePage:
        """Forward a vector store search to the provider that owns ``vector_store_id``."""
        logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
        # Route based on vector store ID
        provider = self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_search_vector_store(
            vector_store_id=vector_store_id,
            query=query,
            filters=filters,
            max_num_results=max_num_results,
            ranking_options=ranking_options,
            rewrite_query=rewrite_query,
        )

    async def openai_attach_file_to_vector_store(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any] | None = None,
        chunking_strategy: VectorStoreChunkingStrategy | None = None,
    ) -> VectorStoreFileObject:
        """Forward a file attachment to the provider that owns ``vector_store_id``."""
        logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}")
        # Route based on vector store ID
        provider = self.routing_table.get_provider_impl(vector_store_id)
        return await provider.openai_attach_file_to_vector_store(
            vector_store_id=vector_store_id,
            file_id=file_id,
            attributes=attributes,
            chunking_strategy=chunking_strategy,
        )