feat: OpenAIVectorIOMixin for vector_stores common logic (#2427)
Extracts the common OpenAI vector-store logic into its own mixin so that all providers can share the same core implementation. This also makes it easy for Llama Stack to support both the OpenAI vector-stores API and the Llama Stack APIs in the interim, with both backed by the same underlying vector DBs. Each provider keeps its storage-specific logic to `create / edit / delete / list` vector DBs, while the plumbing is standardized in the common code. Verified that this works well with both faiss and sqlite-vec.

### Test Plan

```
llama stack run starter
pytest -sv --stack-config http://localhost:8321 tests/integration/vector_io/test_openai_vector_stores.py --embedding-model all-MiniLM-L6-v2
```
This commit is contained in: parent 4e37b49cdc, commit d55100d9b7

11 changed files with 484 additions and 633 deletions
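The pattern the commit introduces is easiest to see from the provider side. Below is a minimal sketch (class and variable names are made up for illustration, not part of the commit) of what a provider now supplies: the four storage hooks plus the vector-db plumbing it already had, with the mixin contributing every `openai_*` endpoint. The diff below follows this shape: faiss backs the hooks with its kvstore, sqlite-vec with its own table.

```python
# Hypothetical provider sketch: the mixin supplies the OpenAI vector-store
# endpoints; the provider supplies only storage-specific hooks.
from typing import Any

from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import QueryChunksResponse
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin


class InMemoryVectorIOAdapter(OpenAIVectorStoreMixin):
    def __init__(self) -> None:
        # in-memory cache the mixin reads from; a real provider fills it in initialize()
        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
        self._storage: dict[str, dict[str, Any]] = {}  # stand-in for kvstore / SQLite

    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
        self._storage[store_id] = store_info

    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
        return dict(self._storage)

    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
        self._storage[store_id] = store_info

    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
        self._storage.pop(store_id, None)

    # the mixin also declares these as abstract; a real provider already has them
    async def register_vector_db(self, vector_db: VectorDB) -> None: ...
    async def unregister_vector_db(self, vector_db_id: str) -> None: ...
    async def query_chunks(
        self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
    ) -> QueryChunksResponse: ...
```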
docs/_static/llama-stack-spec.html (vendored, +3)

```diff
@@ -12583,6 +12583,9 @@
           }
         },
         "additionalProperties": false,
+        "required": [
+          "name"
+        ],
         "title": "OpenaiCreateVectorStoreRequest"
       },
       "VectorStoreObject": {
```
docs/_static/llama-stack-spec.yaml (vendored, +2)

```diff
@@ -8791,6 +8791,8 @@ components:
       description: >-
         The provider-specific vector database ID.
       additionalProperties: false
+    required:
+      - name
     title: OpenaiCreateVectorStoreRequest
   VectorStoreObject:
     type: object
```
```diff
@@ -165,7 +165,7 @@ class VectorIO(Protocol):
     @webmethod(route="/openai/v1/vector_stores", method="POST")
     async def openai_create_vector_store(
         self,
-        name: str | None = None,
+        name: str,
         file_ids: list[str] | None = None,
         expires_after: dict[str, Any] | None = None,
         chunking_strategy: dict[str, Any] | None = None,
```
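Since `name` loses its default here (and the spec above now lists it under `required`), a create request must carry a name. A rough illustration of the request shape, assuming a stack running locally on the port from the test plan and the route served exactly as declared above:

```python
# Illustrative only: plain HTTP against a locally running stack.
# The base URL/port and payload shape are assumptions, not part of the diff.
import requests

resp = requests.post(
    "http://localhost:8321/openai/v1/vector_stores",
    json={"name": "my-docs"},  # omitting "name" should now fail validation
)
print(resp.status_code, resp.json().get("id"))
```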
```diff
@@ -394,9 +394,13 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
             logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
             missing_methods.append((name, "signature_mismatch"))
         else:
-            # Check if the method is actually implemented in the class
-            method_owner = next((cls for cls in mro if name in cls.__dict__), None)
-            if method_owner is None or method_owner.__name__ == protocol.__name__:
+            # Check if the method has a concrete implementation (not just a protocol stub)
+            # Find all classes in MRO that define this method
+            method_owners = [cls for cls in mro if name in cls.__dict__]
+
+            # Allow methods from mixins/parents, only reject if ONLY the protocol defines it
+            if len(method_owners) == 1 and method_owners[0].__name__ == protocol.__name__:
+                # Only reject if the method is ONLY defined in the protocol itself (abstract stub)
                 missing_methods.append((name, "not_actually_implemented"))

     if missing_methods:
```
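The intent of this resolver change can be checked in isolation: collect every class in the MRO that defines the method, and flag it only when the protocol is the sole definer. A toy demonstration (class names invented for the example):

```python
# Toy illustration of the new ownership check; not the real classes.
class VectorIOProto:          # stands in for the protocol with stub methods
    def create(self): ...

class Mixin:                  # concrete shared implementation
    def create(self):
        return "created"

class OnlyProto(VectorIOProto):         # no concrete implementation anywhere
    pass

class WithMixin(Mixin, VectorIOProto):  # implementation comes from the mixin
    pass

def is_implemented(cls: type, protocol: type, name: str) -> bool:
    mro = cls.__mro__
    method_owners = [c for c in mro if name in c.__dict__]
    # reject only when the protocol is the sole definer (abstract stub)
    return not (len(method_owners) == 1 and method_owners[0].__name__ == protocol.__name__)

print(is_implemented(OnlyProto, VectorIOProto, "create"))   # False: only the stub exists
print(is_implemented(WithMixin, VectorIOProto, "create"))   # True: the mixin provides it
```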
```diff
@@ -108,7 +108,7 @@ class VectorIORouter(VectorIO):
     # OpenAI Vector Stores API endpoints
     async def openai_create_vector_store(
         self,
-        name: str | None = None,
+        name: str,
         file_ids: list[str] | None = None,
         expires_after: dict[str, Any] | None = None,
         chunking_strategy: dict[str, Any] | None = None,
```
faiss provider (FaissVectorIOAdapter):

```diff
@@ -9,9 +9,7 @@ import base64
 import io
 import json
 import logging
-import time
-import uuid
-from typing import Any, Literal
+from typing import Any

 import faiss
 import numpy as np
```
```diff
@@ -24,14 +22,11 @@ from llama_stack.apis.vector_io import (
     Chunk,
     QueryChunksResponse,
     VectorIO,
-    VectorStoreDeleteResponse,
-    VectorStoreListResponse,
-    VectorStoreObject,
-    VectorStoreSearchResponse,
 )
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
```
```diff
@@ -47,10 +42,6 @@ FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
 OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"


-# In faiss, since we do
-CHUNK_MULTIPLIER = 5
-
-
 class FaissIndex(EmbeddingIndex):
     def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
         self.index = faiss.IndexFlatL2(dimension)
```
```diff
@@ -140,7 +131,7 @@ class FaissIndex(EmbeddingIndex):
         raise NotImplementedError("Keyword search is not supported in FAISS")


-class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
+class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
     def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
         self.config = config
         self.inference_api = inference_api
```
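The mixin is deliberately listed first in the bases: Python resolves attributes left to right along the MRO, so the mixin's concrete `openai_*` methods win over any same-named stubs on the protocol bases. A toy illustration (invented class names):

```python
# Why the mixin comes first in the base list.
class Proto:
    def openai_create_vector_store(self): ...   # stub

class Mixin:
    def openai_create_vector_store(self):       # concrete
        return "real implementation"

class Adapter(Mixin, Proto):
    pass

print([c.__name__ for c in Adapter.__mro__])
# ['Adapter', 'Mixin', 'Proto', 'object'] -> Mixin wins attribute lookup
print(Adapter().openai_create_vector_store())   # "real implementation"
```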
```diff
@@ -164,14 +155,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
             )
             self.cache[vector_db.identifier] = index

-            # Load existing OpenAI vector stores
-            start_key = OPENAI_VECTOR_STORES_PREFIX
-            end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
-            stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
-
-            for store_data in stored_openai_stores:
-                store_info = json.loads(store_data)
-                self.openai_vector_stores[store_info["id"]] = store_info
+        # Load existing OpenAI vector stores using the mixin method
+        self.openai_vector_stores = await self._load_openai_vector_stores()

     async def shutdown(self) -> None:
         # Cleanup if needed
```
```diff
@@ -234,285 +219,34 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):

         return await index.query_chunks(query, params)

-    # OpenAI Vector Stores API endpoints implementation
-    async def openai_create_vector_store(
-        self,
-        name: str | None = None,
-        file_ids: list[str] | None = None,
-        expires_after: dict[str, Any] | None = None,
-        chunking_strategy: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-        embedding_model: str | None = None,
-        embedding_dimension: int | None = 384,
-        provider_id: str | None = None,
-        provider_vector_db_id: str | None = None,
-    ) -> VectorStoreObject:
-        """Creates a vector store."""
-        assert self.kvstore is not None
-        # store and vector_db have the same id
-        store_id = name or str(uuid.uuid4())
-        created_at = int(time.time())
-
-        if provider_id is None:
-            raise ValueError("Provider ID is required")
-
-        if embedding_model is None:
-            raise ValueError("Embedding model is required")
-
-        # Use provided embedding dimension or default to 384
-        if embedding_dimension is None:
-            raise ValueError("Embedding dimension is required")
-
-        provider_vector_db_id = provider_vector_db_id or store_id
-        vector_db = VectorDB(
-            identifier=store_id,
-            embedding_dimension=embedding_dimension,
-            embedding_model=embedding_model,
-            provider_id=provider_id,
-            provider_resource_id=provider_vector_db_id,
-        )
-
-        # Register the vector DB
-        await self.register_vector_db(vector_db)
-
-        # Create OpenAI vector store metadata
-        store_info = {
-            "id": store_id,
-            "object": "vector_store",
-            "created_at": created_at,
-            "name": store_id,
-            "usage_bytes": 0,
-            "file_counts": {},
-            "status": "completed",
-            "expires_after": expires_after,
-            "expires_at": None,
-            "last_active_at": created_at,
-            "file_ids": file_ids or [],
-            "chunking_strategy": chunking_strategy,
-        }
-
-        # Add provider information to metadata if provided
-        metadata = metadata or {}
-        if provider_id:
-            metadata["provider_id"] = provider_id
-        if provider_vector_db_id:
-            metadata["provider_vector_db_id"] = provider_vector_db_id
-        store_info["metadata"] = metadata
-
-        # Store in kvstore
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-
-        # Store in memory cache
-        self.openai_vector_stores[store_id] = store_info
-
-        return VectorStoreObject(
-            id=store_id,
-            created_at=created_at,
-            name=store_id,
-            usage_bytes=0,
-            file_counts={},
-            status="completed",
-            expires_after=expires_after,
-            expires_at=None,
-            last_active_at=created_at,
-            metadata=metadata,
-        )
-
-    async def openai_list_vector_stores(
-        self,
-        limit: int = 20,
-        order: str = "desc",
-        after: str | None = None,
-        before: str | None = None,
-    ) -> VectorStoreListResponse:
-        """Returns a list of vector stores."""
-        # Get all vector stores
-        all_stores = list(self.openai_vector_stores.values())
-
-        # Sort by created_at
-        reverse_order = order == "desc"
-        all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order)
-
-        # Apply cursor-based pagination
-        if after:
-            after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
-            if after_index >= 0:
-                all_stores = all_stores[after_index + 1 :]
-
-        if before:
-            before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
-            all_stores = all_stores[:before_index]
-
-        # Apply limit
-        limited_stores = all_stores[:limit]
-        # Convert to VectorStoreObject instances
-        data = [VectorStoreObject(**store) for store in limited_stores]
-
-        # Determine pagination info
-        has_more = len(all_stores) > limit
-        first_id = data[0].id if data else None
-        last_id = data[-1].id if data else None
-
-        return VectorStoreListResponse(
-            data=data,
-            has_more=has_more,
-            first_id=first_id,
-            last_id=last_id,
-        )
-
-    async def openai_retrieve_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreObject:
-        """Retrieves a vector store."""
-        if vector_store_id not in self.openai_vector_stores:
-            raise ValueError(f"Vector store {vector_store_id} not found")
-
-        store_info = self.openai_vector_stores[vector_store_id]
-        return VectorStoreObject(**store_info)
-
-    async def openai_update_vector_store(
-        self,
-        vector_store_id: str,
-        name: str | None = None,
-        expires_after: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-    ) -> VectorStoreObject:
-        """Modifies a vector store."""
-        assert self.kvstore is not None
-        if vector_store_id not in self.openai_vector_stores:
-            raise ValueError(f"Vector store {vector_store_id} not found")
-
-        store_info = self.openai_vector_stores[vector_store_id].copy()
-
-        # Update fields if provided
-        if name is not None:
-            store_info["name"] = name
-        if expires_after is not None:
-            store_info["expires_after"] = expires_after
-        if metadata is not None:
-            store_info["metadata"] = metadata
-
-        # Update last_active_at
-        store_info["last_active_at"] = int(time.time())
-
-        # Save to kvstore
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{vector_store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-
-        # Update in-memory cache
-        self.openai_vector_stores[vector_store_id] = store_info
-
-        return VectorStoreObject(**store_info)
-
-    async def openai_delete_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreDeleteResponse:
-        """Delete a vector store."""
-        assert self.kvstore is not None
-        if vector_store_id not in self.openai_vector_stores:
-            raise ValueError(f"Vector store {vector_store_id} not found")
-
-        # Delete from kvstore
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{vector_store_id}"
-        await self.kvstore.delete(key)
-
-        # Delete from in-memory cache
-        del self.openai_vector_stores[vector_store_id]
-
-        # Also delete the underlying vector DB
-        try:
-            await self.unregister_vector_db(vector_store_id)
-        except Exception as e:
-            logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
-
-        return VectorStoreDeleteResponse(
-            id=vector_store_id,
-            deleted=True,
-        )
-
-    async def openai_search_vector_store(
-        self,
-        vector_store_id: str,
-        query: str | list[str],
-        filters: dict[str, Any] | None = None,
-        max_num_results: int = 10,
-        ranking_options: dict[str, Any] | None = None,
-        rewrite_query: bool = False,
-        search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
-    ) -> VectorStoreSearchResponse:
-        """Search for chunks in a vector store."""
-        if vector_store_id not in self.openai_vector_stores:
-            raise ValueError(f"Vector store {vector_store_id} not found")
-
-        if isinstance(query, list):
-            search_query = " ".join(query)
-        else:
-            search_query = query
-
-        try:
-            score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
-            params = {
-                "max_chunks": max_num_results * CHUNK_MULTIPLIER,
-                "score_threshold": score_threshold,
-                "mode": search_mode,
-            }
-            # TODO: Add support for ranking_options.ranker
-
-            response = await self.query_chunks(
-                vector_db_id=vector_store_id,
-                query=search_query,
-                params=params,
-            )
-
-            # Convert response to OpenAI format
-            data = []
-            for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)):
-                # Apply score based filtering
-                if score < score_threshold:
-                    continue
-
-                # Apply filters if provided
-                if filters:
-                    # Simple metadata filtering
-                    if not self._matches_filters(chunk.metadata, filters):
-                        continue
-
-                chunk_data = {
-                    "id": f"chunk_{i}",
-                    "object": "vector_store.search_result",
-                    "score": score,
-                    "content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content),
-                    "metadata": chunk.metadata,
-                }
-                data.append(chunk_data)
-                if len(data) >= max_num_results:
-                    break
-
-            return VectorStoreSearchResponse(
-                search_query=search_query,
-                data=data,
-                has_more=False,  # For simplicity, we don't implement pagination here
-                next_page=None,
-            )
-
-        except Exception as e:
-            logger.error(f"Error searching vector store {vector_store_id}: {e}")
-            # Return empty results on error
-            return VectorStoreSearchResponse(
-                search_query=search_query,
-                data=[],
-                has_more=False,
-                next_page=None,
-            )
-
-    def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
-        """Check if metadata matches the provided filters."""
-        for key, value in filters.items():
-            if key not in metadata:
-                return False
-            if metadata[key] != value:
-                return False
-        return True
+    # OpenAI Vector Store Mixin abstract method implementations
+    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+        """Save vector store metadata to kvstore."""
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.set(key=key, value=json.dumps(store_info))
+
+    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
+        """Load all vector store metadata from kvstore."""
+        assert self.kvstore is not None
+        start_key = OPENAI_VECTOR_STORES_PREFIX
+        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
+        stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
+
+        stores = {}
+        for store_data in stored_openai_stores:
+            store_info = json.loads(store_data)
+            stores[store_info["id"]] = store_info
+        return stores
+
+    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+        """Update vector store metadata in kvstore."""
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.set(key=key, value=json.dumps(store_info))
+
+    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
+        """Delete vector store metadata from kvstore."""
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.delete(key)
```
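The faiss hooks above persist each store under a version-prefixed key and reload everything with a range scan from the prefix to `prefix + "\xff"`, which captures exactly the keys sharing that prefix. The same pattern, sketched over a plain sorted mapping (the real `KVStore` interface is only assumed to behave like the `values_in_range` call in the diff):

```python
# Sketch of the prefix range-scan pattern: scanning from `prefix` to
# `prefix + "\xff"` selects every key beginning with `prefix`, because
# "\xff" sorts after any printable character in the key space.
import json

PREFIX = "openai_vector_stores:v3::"

kv = {
    f"{PREFIX}store-a": json.dumps({"id": "store-a"}),
    f"{PREFIX}store-b": json.dumps({"id": "store-b"}),
    "faiss_index:v3::bank-1": "unrelated",
}

start_key, end_key = PREFIX, f"{PREFIX}\xff"
values = [v for k, v in sorted(kv.items()) if start_key <= k <= end_key]

stores = {}
for store_data in values:
    info = json.loads(store_data)
    stores[info["id"]] = info
print(sorted(stores))  # ['store-a', 'store-b']
```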
sqlite-vec provider (SQLiteVecVectorIOAdapter):

```diff
@@ -10,9 +10,8 @@ import json
 import logging
 import sqlite3
 import struct
-import time
 import uuid
-from typing import Any, Literal
+from typing import Any

 import numpy as np
 import sqlite_vec
```
```diff
@@ -24,12 +23,9 @@ from llama_stack.apis.vector_io import (
     Chunk,
     QueryChunksResponse,
     VectorIO,
-    VectorStoreDeleteResponse,
-    VectorStoreListResponse,
-    VectorStoreObject,
-    VectorStoreSearchResponse,
 )
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
+from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex

 logger = logging.getLogger(__name__)
```
|
||||||
KEYWORD_SEARCH = "keyword"
|
KEYWORD_SEARCH = "keyword"
|
||||||
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
|
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
|
||||||
|
|
||||||
# Constants for OpenAI vector stores (similar to faiss)
|
|
||||||
VERSION = "v3"
|
|
||||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
|
|
||||||
CHUNK_MULTIPLIER = 5
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_vector(vector: list[float]) -> bytes:
|
def serialize_vector(vector: list[float]) -> bytes:
|
||||||
"""Serialize a list of floats into a compact binary representation."""
|
"""Serialize a list of floats into a compact binary representation."""
|
||||||
|
```diff
@@ -303,7 +294,7 @@ class SQLiteVecIndex(EmbeddingIndex):
         return QueryChunksResponse(chunks=chunks, scores=scores)


-class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
+class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
     """
     A VectorIO implementation using SQLite + sqlite_vec.
     This class handles vector database registration (with metadata stored in a table named `vector_dbs`)
```
```diff
@@ -340,15 +331,12 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
                 # Load any existing vector DB registrations.
                 cur.execute("SELECT metadata FROM vector_dbs")
                 vector_db_rows = cur.fetchall()
-                # Load any existing OpenAI vector stores.
-                cur.execute("SELECT metadata FROM openai_vector_stores")
-                openai_store_rows = cur.fetchall()
-                return vector_db_rows, openai_store_rows
+                return vector_db_rows
             finally:
                 cur.close()
                 connection.close()

-        vector_db_rows, openai_store_rows = await asyncio.to_thread(_setup_connection)
+        vector_db_rows = await asyncio.to_thread(_setup_connection)

         # Load existing vector DBs
         for row in vector_db_rows:
```
|
||||||
)
|
)
|
||||||
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
|
||||||
|
|
||||||
# Load existing OpenAI vector stores
|
# Load existing OpenAI vector stores using the mixin method
|
||||||
for row in openai_store_rows:
|
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||||
store_data = row[0]
|
|
||||||
store_info = json.loads(store_data)
|
|
||||||
self.openai_vector_stores[store_info["id"]] = store_info
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
# nothing to do since we don't maintain a persistent connection
|
# nothing to do since we don't maintain a persistent connection
|
||||||
|
@ -409,6 +394,87 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
|
|
||||||
await asyncio.to_thread(_delete_vector_db_from_registry)
|
await asyncio.to_thread(_delete_vector_db_from_registry)
|
||||||
|
|
||||||
|
# OpenAI Vector Store Mixin abstract method implementations
|
||||||
|
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
|
||||||
|
"""Save vector store metadata to SQLite database."""
|
||||||
|
|
||||||
|
def _store():
|
||||||
|
connection = _create_sqlite_connection(self.config.db_path)
|
||||||
|
cur = connection.cursor()
|
||||||
|
try:
|
||||||
|
cur.execute(
|
||||||
|
"INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)",
|
||||||
|
(store_id, json.dumps(store_info)),
|
||||||
|
)
|
||||||
|
connection.commit()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving openai vector store {store_id}: {e}")
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
cur.close()
|
||||||
|
connection.close()
|
||||||
|
|
||||||
|
try:
|
||||||
|
await asyncio.to_thread(_store)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error saving openai vector store {store_id}: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
|
||||||
|
"""Load all vector store metadata from SQLite database."""
|
||||||
|
|
||||||
|
def _load():
|
||||||
|
connection = _create_sqlite_connection(self.config.db_path)
|
||||||
|
cur = connection.cursor()
|
||||||
|
try:
|
||||||
|
cur.execute("SELECT metadata FROM openai_vector_stores")
|
||||||
|
rows = cur.fetchall()
|
||||||
|
return rows
|
||||||
|
finally:
|
||||||
|
cur.close()
|
||||||
|
connection.close()
|
||||||
|
|
||||||
|
rows = await asyncio.to_thread(_load)
|
||||||
|
stores = {}
|
||||||
|
for row in rows:
|
||||||
|
store_data = row[0]
|
||||||
|
store_info = json.loads(store_data)
|
||||||
|
stores[store_info["id"]] = store_info
|
||||||
|
return stores
|
||||||
|
|
||||||
|
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
|
||||||
|
"""Update vector store metadata in SQLite database."""
|
||||||
|
|
||||||
|
def _update():
|
||||||
|
connection = _create_sqlite_connection(self.config.db_path)
|
||||||
|
cur = connection.cursor()
|
||||||
|
try:
|
||||||
|
cur.execute(
|
||||||
|
"UPDATE openai_vector_stores SET metadata = ? WHERE id = ?",
|
||||||
|
(json.dumps(store_info), store_id),
|
||||||
|
)
|
||||||
|
connection.commit()
|
||||||
|
finally:
|
||||||
|
cur.close()
|
||||||
|
connection.close()
|
||||||
|
|
||||||
|
await asyncio.to_thread(_update)
|
||||||
|
|
||||||
|
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
|
||||||
|
"""Delete vector store metadata from SQLite database."""
|
||||||
|
|
||||||
|
def _delete():
|
||||||
|
connection = _create_sqlite_connection(self.config.db_path)
|
||||||
|
cur = connection.cursor()
|
||||||
|
try:
|
||||||
|
cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (store_id,))
|
||||||
|
connection.commit()
|
||||||
|
finally:
|
||||||
|
cur.close()
|
||||||
|
connection.close()
|
||||||
|
|
||||||
|
await asyncio.to_thread(_delete)
|
||||||
|
|
||||||
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
|
||||||
if vector_db_id not in self.cache:
|
if vector_db_id not in self.cache:
|
||||||
raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
|
raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
|
||||||
|
@ -423,318 +489,6 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||||
raise ValueError(f"Vector DB {vector_db_id} not found")
|
raise ValueError(f"Vector DB {vector_db_id} not found")
|
||||||
return await self.cache[vector_db_id].query_chunks(query, params)
|
return await self.cache[vector_db_id].query_chunks(query, params)
|
||||||
|
|
||||||
async def openai_create_vector_store(
|
|
||||||
self,
|
|
||||||
name: str | None = None,
|
|
||||||
file_ids: list[str] | None = None,
|
|
||||||
expires_after: dict[str, Any] | None = None,
|
|
||||||
chunking_strategy: dict[str, Any] | None = None,
|
|
||||||
metadata: dict[str, Any] | None = None,
|
|
||||||
embedding_model: str | None = None,
|
|
||||||
embedding_dimension: int | None = 384,
|
|
||||||
provider_id: str | None = None,
|
|
||||||
provider_vector_db_id: str | None = None,
|
|
||||||
) -> VectorStoreObject:
|
|
||||||
"""Creates a vector store."""
|
|
||||||
# store and vector_db have the same id
|
|
||||||
store_id = name or str(uuid.uuid4())
|
|
||||||
created_at = int(time.time())
|
|
||||||
|
|
||||||
if provider_id is None:
|
|
||||||
raise ValueError("Provider ID is required")
|
|
||||||
|
|
||||||
if embedding_model is None:
|
|
||||||
raise ValueError("Embedding model is required")
|
|
||||||
|
|
||||||
# Use provided embedding dimension or default to 384
|
|
||||||
if embedding_dimension is None:
|
|
||||||
raise ValueError("Embedding dimension is required")
|
|
||||||
|
|
||||||
provider_vector_db_id = provider_vector_db_id or store_id
|
|
||||||
vector_db = VectorDB(
|
|
||||||
identifier=store_id,
|
|
||||||
embedding_dimension=embedding_dimension,
|
|
||||||
embedding_model=embedding_model,
|
|
||||||
provider_id=provider_id,
|
|
||||||
provider_resource_id=provider_vector_db_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Register the vector DB
|
|
||||||
await self.register_vector_db(vector_db)
|
|
||||||
|
|
||||||
# Create OpenAI vector store metadata
|
|
||||||
store_info = {
|
|
||||||
"id": store_id,
|
|
||||||
"object": "vector_store",
|
|
||||||
"created_at": created_at,
|
|
||||||
"name": store_id,
|
|
||||||
"usage_bytes": 0,
|
|
||||||
"file_counts": {},
|
|
||||||
"status": "completed",
|
|
||||||
"expires_after": expires_after,
|
|
||||||
"expires_at": None,
|
|
||||||
"last_active_at": created_at,
|
|
||||||
"file_ids": file_ids or [],
|
|
||||||
"chunking_strategy": chunking_strategy,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Add provider information to metadata if provided
|
|
||||||
metadata = metadata or {}
|
|
||||||
if provider_id:
|
|
||||||
metadata["provider_id"] = provider_id
|
|
||||||
if provider_vector_db_id:
|
|
||||||
metadata["provider_vector_db_id"] = provider_vector_db_id
|
|
||||||
store_info["metadata"] = metadata
|
|
||||||
|
|
||||||
# Store in SQLite database
|
|
||||||
def _store_openai_vector_store():
|
|
||||||
connection = _create_sqlite_connection(self.config.db_path)
|
|
||||||
cur = connection.cursor()
|
|
||||||
try:
|
|
||||||
cur.execute(
|
|
||||||
"INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)",
|
|
||||||
(store_id, json.dumps(store_info)),
|
|
||||||
)
|
|
||||||
connection.commit()
|
|
||||||
finally:
|
|
||||||
cur.close()
|
|
||||||
connection.close()
|
|
||||||
|
|
||||||
await asyncio.to_thread(_store_openai_vector_store)
|
|
||||||
|
|
||||||
# Store in memory cache
|
|
||||||
self.openai_vector_stores[store_id] = store_info
|
|
||||||
|
|
||||||
return VectorStoreObject(
|
|
||||||
id=store_id,
|
|
||||||
created_at=created_at,
|
|
||||||
name=store_id,
|
|
||||||
usage_bytes=0,
|
|
||||||
file_counts={},
|
|
||||||
status="completed",
|
|
||||||
expires_after=expires_after,
|
|
||||||
expires_at=None,
|
|
||||||
last_active_at=created_at,
|
|
||||||
metadata=metadata,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def openai_list_vector_stores(
|
|
||||||
self,
|
|
||||||
limit: int = 20,
|
|
||||||
order: str = "desc",
|
|
||||||
after: str | None = None,
|
|
||||||
before: str | None = None,
|
|
||||||
) -> VectorStoreListResponse:
|
|
||||||
"""Returns a list of vector stores."""
|
|
||||||
# Get all vector stores
|
|
||||||
all_stores = list(self.openai_vector_stores.values())
|
|
||||||
|
|
||||||
# Sort by created_at
|
|
||||||
reverse_order = order == "desc"
|
|
||||||
all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order)
|
|
||||||
|
|
||||||
# Apply cursor-based pagination
|
|
||||||
if after:
|
|
||||||
after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
|
|
||||||
if after_index >= 0:
|
|
||||||
all_stores = all_stores[after_index + 1 :]
|
|
||||||
|
|
||||||
if before:
|
|
||||||
before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
|
|
||||||
all_stores = all_stores[:before_index]
|
|
||||||
|
|
||||||
# Apply limit
|
|
||||||
limited_stores = all_stores[:limit]
|
|
||||||
# Convert to VectorStoreObject instances
|
|
||||||
data = [VectorStoreObject(**store) for store in limited_stores]
|
|
||||||
|
|
||||||
# Determine pagination info
|
|
||||||
has_more = len(all_stores) > limit
|
|
||||||
first_id = data[0].id if data else None
|
|
||||||
last_id = data[-1].id if data else None
|
|
||||||
|
|
||||||
return VectorStoreListResponse(
|
|
||||||
data=data,
|
|
||||||
has_more=has_more,
|
|
||||||
first_id=first_id,
|
|
||||||
last_id=last_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def openai_retrieve_vector_store(
|
|
||||||
self,
|
|
||||||
vector_store_id: str,
|
|
||||||
) -> VectorStoreObject:
|
|
||||||
"""Retrieves a vector store."""
|
|
||||||
if vector_store_id not in self.openai_vector_stores:
|
|
||||||
raise ValueError(f"Vector store {vector_store_id} not found")
|
|
||||||
|
|
||||||
store_info = self.openai_vector_stores[vector_store_id]
|
|
||||||
return VectorStoreObject(**store_info)
|
|
||||||
|
|
||||||
async def openai_update_vector_store(
|
|
||||||
self,
|
|
||||||
vector_store_id: str,
|
|
||||||
name: str | None = None,
|
|
||||||
expires_after: dict[str, Any] | None = None,
|
|
||||||
metadata: dict[str, Any] | None = None,
|
|
||||||
) -> VectorStoreObject:
|
|
||||||
"""Modifies a vector store."""
|
|
||||||
if vector_store_id not in self.openai_vector_stores:
|
|
||||||
raise ValueError(f"Vector store {vector_store_id} not found")
|
|
||||||
|
|
||||||
store_info = self.openai_vector_stores[vector_store_id].copy()
|
|
||||||
|
|
||||||
# Update fields if provided
|
|
||||||
if name is not None:
|
|
||||||
store_info["name"] = name
|
|
||||||
if expires_after is not None:
|
|
||||||
store_info["expires_after"] = expires_after
|
|
||||||
if metadata is not None:
|
|
||||||
store_info["metadata"] = metadata
|
|
||||||
|
|
||||||
# Update last_active_at
|
|
||||||
store_info["last_active_at"] = int(time.time())
|
|
||||||
|
|
||||||
# Save to SQLite database
|
|
||||||
def _update_openai_vector_store():
|
|
||||||
connection = _create_sqlite_connection(self.config.db_path)
|
|
||||||
cur = connection.cursor()
|
|
||||||
try:
|
|
||||||
cur.execute(
|
|
||||||
"UPDATE openai_vector_stores SET metadata = ? WHERE id = ?",
|
|
||||||
(json.dumps(store_info), vector_store_id),
|
|
||||||
)
|
|
||||||
connection.commit()
|
|
||||||
finally:
|
|
||||||
cur.close()
|
|
||||||
connection.close()
|
|
||||||
|
|
||||||
await asyncio.to_thread(_update_openai_vector_store)
|
|
||||||
|
|
||||||
# Update in-memory cache
|
|
||||||
self.openai_vector_stores[vector_store_id] = store_info
|
|
||||||
|
|
||||||
return VectorStoreObject(**store_info)
|
|
||||||
|
|
||||||
async def openai_delete_vector_store(
|
|
||||||
self,
|
|
||||||
vector_store_id: str,
|
|
||||||
) -> VectorStoreDeleteResponse:
|
|
||||||
"""Delete a vector store."""
|
|
||||||
if vector_store_id not in self.openai_vector_stores:
|
|
||||||
raise ValueError(f"Vector store {vector_store_id} not found")
|
|
||||||
|
|
||||||
# Delete from SQLite database
|
|
||||||
def _delete_openai_vector_store():
|
|
||||||
connection = _create_sqlite_connection(self.config.db_path)
|
|
||||||
cur = connection.cursor()
|
|
||||||
try:
|
|
||||||
cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (vector_store_id,))
|
|
||||||
connection.commit()
|
|
||||||
finally:
|
|
||||||
cur.close()
|
|
||||||
connection.close()
|
|
||||||
|
|
||||||
await asyncio.to_thread(_delete_openai_vector_store)
|
|
||||||
|
|
||||||
# Delete from in-memory cache
|
|
||||||
del self.openai_vector_stores[vector_store_id]
|
|
||||||
|
|
||||||
# Also delete the underlying vector DB
|
|
||||||
try:
|
|
||||||
await self.unregister_vector_db(vector_store_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
|
|
||||||
|
|
||||||
return VectorStoreDeleteResponse(
|
|
||||||
id=vector_store_id,
|
|
||||||
deleted=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def openai_search_vector_store(
|
|
||||||
self,
|
|
||||||
vector_store_id: str,
|
|
||||||
query: str | list[str],
|
|
||||||
filters: dict[str, Any] | None = None,
|
|
||||||
max_num_results: int = 10,
|
|
||||||
ranking_options: dict[str, Any] | None = None,
|
|
||||||
rewrite_query: bool = False,
|
|
||||||
search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
|
|
||||||
) -> VectorStoreSearchResponse:
|
|
||||||
"""Search for chunks in a vector store."""
|
|
||||||
if vector_store_id not in self.openai_vector_stores:
|
|
||||||
raise ValueError(f"Vector store {vector_store_id} not found")
|
|
||||||
|
|
||||||
if isinstance(query, list):
|
|
||||||
search_query = " ".join(query)
|
|
||||||
else:
|
|
||||||
search_query = query
|
|
||||||
|
|
||||||
try:
|
|
||||||
score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
|
|
||||||
params = {
|
|
||||||
"max_chunks": max_num_results * CHUNK_MULTIPLIER,
|
|
||||||
"score_threshold": score_threshold,
|
|
||||||
"mode": search_mode,
|
|
||||||
}
|
|
||||||
# TODO: Add support for ranking_options.ranker
|
|
||||||
|
|
||||||
response = await self.query_chunks(
|
|
||||||
vector_db_id=vector_store_id,
|
|
||||||
query=search_query,
|
|
||||||
params=params,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Convert response to OpenAI format
|
|
||||||
data = []
|
|
||||||
for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)):
|
|
||||||
# Apply score based filtering
|
|
||||||
if score < score_threshold:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Apply filters if provided
|
|
||||||
if filters:
|
|
||||||
# Simple metadata filtering
|
|
||||||
if not self._matches_filters(chunk.metadata, filters):
|
|
||||||
continue
|
|
||||||
|
|
||||||
chunk_data = {
|
|
||||||
"id": f"chunk_{i}",
|
|
||||||
"object": "vector_store.search_result",
|
|
||||||
"score": score,
|
|
||||||
"content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content),
|
|
||||||
"metadata": chunk.metadata,
|
|
||||||
}
|
|
||||||
data.append(chunk_data)
|
|
||||||
if len(data) >= max_num_results:
|
|
||||||
break
|
|
||||||
|
|
||||||
return VectorStoreSearchResponse(
|
|
||||||
search_query=search_query,
|
|
||||||
data=data,
|
|
||||||
has_more=False, # For simplicity, we don't implement pagination here
|
|
||||||
next_page=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error searching vector store {vector_store_id}: {e}")
|
|
||||||
# Return empty results on error
|
|
||||||
return VectorStoreSearchResponse(
|
|
||||||
search_query=search_query,
|
|
||||||
data=[],
|
|
||||||
has_more=False,
|
|
||||||
next_page=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
|
|
||||||
"""Check if metadata matches the provided filters."""
|
|
||||||
for key, value in filters.items():
|
|
||||||
if key not in metadata:
|
|
||||||
return False
|
|
||||||
if metadata[key] != value:
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
|
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
|
||||||
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
|
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
|
||||||
|
|
|
```diff
@@ -189,7 +189,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):

     async def openai_create_vector_store(
         self,
-        name: str | None = None,
+        name: str,
         file_ids: list[str] | None = None,
         expires_after: dict[str, Any] | None = None,
         chunking_strategy: dict[str, Any] | None = None,
```
```diff
@@ -187,7 +187,7 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):

     async def openai_create_vector_store(
         self,
-        name: str | None = None,
+        name: str,
         file_ids: list[str] | None = None,
         expires_after: dict[str, Any] | None = None,
         chunking_strategy: dict[str, Any] | None = None,
```
```diff
@@ -189,7 +189,7 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):

     async def openai_create_vector_store(
         self,
-        name: str | None = None,
+        name: str,
         file_ids: list[str] | None = None,
         expires_after: dict[str, Any] | None = None,
         chunking_strategy: dict[str, Any] | None = None,
```
llama_stack/providers/utils/memory/openai_vector_store_mixin.py (new file, 354 lines)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import time
import uuid
from abc import ABC, abstractmethod
from typing import Any, Literal

from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
    QueryChunksResponse,
    VectorStoreDeleteResponse,
    VectorStoreListResponse,
    VectorStoreObject,
    VectorStoreSearchResponse,
)

logger = logging.getLogger(__name__)

# Constants for OpenAI vector stores
CHUNK_MULTIPLIER = 5


class OpenAIVectorStoreMixin(ABC):
    """
    Mixin class that provides common OpenAI Vector Store API implementation.
    Providers need to implement the abstract storage methods and maintain
    an openai_vector_stores in-memory cache.
    """

    # These should be provided by the implementing class
    openai_vector_stores: dict[str, dict[str, Any]]

    @abstractmethod
    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
        """Save vector store metadata to persistent storage."""
        pass

    @abstractmethod
    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
        """Load all vector store metadata from persistent storage."""
        pass

    @abstractmethod
    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
        """Update vector store metadata in persistent storage."""
        pass

    @abstractmethod
    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
        """Delete vector store metadata from persistent storage."""
        pass

    @abstractmethod
    async def register_vector_db(self, vector_db: VectorDB) -> None:
        """Register a vector database (provider-specific implementation)."""
        pass

    @abstractmethod
    async def unregister_vector_db(self, vector_db_id: str) -> None:
        """Unregister a vector database (provider-specific implementation)."""
        pass

    @abstractmethod
    async def query_chunks(
        self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
    ) -> QueryChunksResponse:
        """Query chunks from a vector database (provider-specific implementation)."""
        pass

    async def openai_create_vector_store(
        self,
        name: str,
        file_ids: list[str] | None = None,
        expires_after: dict[str, Any] | None = None,
        chunking_strategy: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
        embedding_model: str | None = None,
        embedding_dimension: int | None = 384,
        provider_id: str | None = None,
        provider_vector_db_id: str | None = None,
    ) -> VectorStoreObject:
        """Creates a vector store."""
        print("IN OPENAI VECTOR STORE MIXIN, openai_create_vector_store")
        # store and vector_db have the same id
        store_id = name or str(uuid.uuid4())
        created_at = int(time.time())

        if provider_id is None:
            raise ValueError("Provider ID is required")

        if embedding_model is None:
            raise ValueError("Embedding model is required")

        # Use provided embedding dimension or default to 384
        if embedding_dimension is None:
            raise ValueError("Embedding dimension is required")

        provider_vector_db_id = provider_vector_db_id or store_id
        vector_db = VectorDB(
            identifier=store_id,
            embedding_dimension=embedding_dimension,
            embedding_model=embedding_model,
            provider_id=provider_id,
            provider_resource_id=provider_vector_db_id,
        )
        from rich.pretty import pprint

        print("VECTOR DB")
        pprint(vector_db)

        # Register the vector DB
        await self.register_vector_db(vector_db)

        # Create OpenAI vector store metadata
        store_info = {
            "id": store_id,
            "object": "vector_store",
            "created_at": created_at,
            "name": store_id,
            "usage_bytes": 0,
            "file_counts": {},
            "status": "completed",
            "expires_after": expires_after,
            "expires_at": None,
            "last_active_at": created_at,
            "file_ids": file_ids or [],
            "chunking_strategy": chunking_strategy,
        }

        # Add provider information to metadata if provided
        metadata = metadata or {}
        if provider_id:
            metadata["provider_id"] = provider_id
        if provider_vector_db_id:
            metadata["provider_vector_db_id"] = provider_vector_db_id
        store_info["metadata"] = metadata

        # Save to persistent storage (provider-specific)
        await self._save_openai_vector_store(store_id, store_info)

        # Store in memory cache
        self.openai_vector_stores[store_id] = store_info

        return VectorStoreObject(
            id=store_id,
            created_at=created_at,
            name=store_id,
            usage_bytes=0,
            file_counts={},
            status="completed",
            expires_after=expires_after,
            expires_at=None,
            last_active_at=created_at,
            metadata=metadata,
        )

    async def openai_list_vector_stores(
        self,
        limit: int = 20,
        order: str = "desc",
        after: str | None = None,
        before: str | None = None,
    ) -> VectorStoreListResponse:
        """Returns a list of vector stores."""
        # Get all vector stores
        all_stores = list(self.openai_vector_stores.values())

        # Sort by created_at
        reverse_order = order == "desc"
        all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order)

        # Apply cursor-based pagination
        if after:
            after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
            if after_index >= 0:
                all_stores = all_stores[after_index + 1 :]

        if before:
            before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
            all_stores = all_stores[:before_index]

        # Apply limit
        limited_stores = all_stores[:limit]
        # Convert to VectorStoreObject instances
        data = [VectorStoreObject(**store) for store in limited_stores]

        # Determine pagination info
        has_more = len(all_stores) > limit
        first_id = data[0].id if data else None
        last_id = data[-1].id if data else None

        return VectorStoreListResponse(
            data=data,
            has_more=has_more,
            first_id=first_id,
            last_id=last_id,
        )

    async def openai_retrieve_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreObject:
        """Retrieves a vector store."""
        if vector_store_id not in self.openai_vector_stores:
            raise ValueError(f"Vector store {vector_store_id} not found")

        store_info = self.openai_vector_stores[vector_store_id]
        return VectorStoreObject(**store_info)

    async def openai_update_vector_store(
        self,
        vector_store_id: str,
        name: str | None = None,
        expires_after: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> VectorStoreObject:
        """Modifies a vector store."""
        if vector_store_id not in self.openai_vector_stores:
            raise ValueError(f"Vector store {vector_store_id} not found")

        store_info = self.openai_vector_stores[vector_store_id].copy()

        # Update fields if provided
        if name is not None:
            store_info["name"] = name
        if expires_after is not None:
            store_info["expires_after"] = expires_after
        if metadata is not None:
            store_info["metadata"] = metadata

        # Update last_active_at
        store_info["last_active_at"] = int(time.time())

        # Save to persistent storage (provider-specific)
        await self._update_openai_vector_store(vector_store_id, store_info)

        # Update in-memory cache
        self.openai_vector_stores[vector_store_id] = store_info

        return VectorStoreObject(**store_info)

    async def openai_delete_vector_store(
        self,
        vector_store_id: str,
    ) -> VectorStoreDeleteResponse:
        """Delete a vector store."""
        if vector_store_id not in self.openai_vector_stores:
            raise ValueError(f"Vector store {vector_store_id} not found")

        # Delete from persistent storage (provider-specific)
        await self._delete_openai_vector_store_from_storage(vector_store_id)

        # Delete from in-memory cache
        del self.openai_vector_stores[vector_store_id]

        # Also delete the underlying vector DB
        try:
            await self.unregister_vector_db(vector_store_id)
        except Exception as e:
            logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")

        return VectorStoreDeleteResponse(
            id=vector_store_id,
            deleted=True,
        )

    async def openai_search_vector_store(
        self,
        vector_store_id: str,
        query: str | list[str],
        filters: dict[str, Any] | None = None,
        max_num_results: int = 10,
        ranking_options: dict[str, Any] | None = None,
        rewrite_query: bool = False,
        search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
    ) -> VectorStoreSearchResponse:
        """Search for chunks in a vector store."""
        if vector_store_id not in self.openai_vector_stores:
            raise ValueError(f"Vector store {vector_store_id} not found")

        if isinstance(query, list):
            search_query = " ".join(query)
        else:
            search_query = query

        try:
            score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
            params = {
                "max_chunks": max_num_results * CHUNK_MULTIPLIER,
                "score_threshold": score_threshold,
                "mode": search_mode,
            }
            # TODO: Add support for ranking_options.ranker

            response = await self.query_chunks(
                vector_db_id=vector_store_id,
                query=search_query,
                params=params,
            )

            # Convert response to OpenAI format
            data = []
            for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)):
                # Apply score based filtering
                if score < score_threshold:
                    continue

                # Apply filters if provided
                if filters:
                    # Simple metadata filtering
                    if not self._matches_filters(chunk.metadata, filters):
                        continue

                chunk_data = {
                    "id": f"chunk_{i}",
                    "object": "vector_store.search_result",
                    "score": score,
                    "content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content),
                    "metadata": chunk.metadata,
                }
                data.append(chunk_data)
                if len(data) >= max_num_results:
                    break

            return VectorStoreSearchResponse(
                search_query=search_query,
                data=data,
                has_more=False,  # For simplicity, we don't implement pagination here
                next_page=None,
            )

        except Exception as e:
            logger.error(f"Error searching vector store {vector_store_id}: {e}")
            # Return empty results on error
            return VectorStoreSearchResponse(
                search_query=search_query,
                data=[],
                has_more=False,
                next_page=None,
            )

    def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
        """Check if metadata matches the provided filters."""
        for key, value in filters.items():
            if key not in metadata:
                return False
            if metadata[key] != value:
                return False
        return True
```
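With the mixin in place, any adapter that inherits it exposes the full create/list/retrieve/update/delete/search surface. A sketch of the call flow against some concrete adapter instance (`adapter`, the model name, and the provider id are placeholders, not fixed by the commit):

```python
# Assumed usage flow against any adapter that mixes in OpenAIVectorStoreMixin.
import asyncio


async def demo(adapter) -> None:
    store = await adapter.openai_create_vector_store(
        name="my-docs",
        embedding_model="all-MiniLM-L6-v2",  # model from the test plan
        embedding_dimension=384,
        provider_id="faiss",                 # illustrative provider id
    )

    listed = await adapter.openai_list_vector_stores(limit=10, order="desc")
    assert any(s.id == store.id for s in listed.data)

    results = await adapter.openai_search_vector_store(
        vector_store_id=store.id,
        query="what does the mixin standardize?",
        max_num_results=5,
    )
    print(len(results.data), "results")

    await adapter.openai_delete_vector_store(vector_store_id=store.id)

# asyncio.run(demo(adapter)) once an adapter instance is initialized
```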