Merge branch 'main' into feat/add-url-to-paginated-response

Rohan Awhad 2025-06-13 13:07:45 -04:00 committed by GitHub
commit b5047db685
24 changed files with 911 additions and 856 deletions

View file

@@ -8,7 +8,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Protocol, runtime_checkable
from typing import Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, Field
@@ -96,13 +96,30 @@ class VectorStoreSearchRequest(BaseModel):
rewrite_query: bool = False
@json_schema_type
class VectorStoreContent(BaseModel):
type: Literal["text"]
text: str
@json_schema_type
class VectorStoreSearchResponse(BaseModel):
"""Response from searching a vector store."""
file_id: str
filename: str
score: float
attributes: dict[str, str | float | bool] | None = None
content: list[VectorStoreContent]
@json_schema_type
class VectorStoreSearchResponsePage(BaseModel):
"""Response from searching a vector store."""
object: str = "vector_store.search_results.page"
search_query: str
data: list[dict[str, Any]]
data: list[VectorStoreSearchResponse]
has_more: bool = False
next_page: str | None = None
@@ -165,7 +182,7 @@ class VectorIO(Protocol):
@webmethod(route="/openai/v1/vector_stores", method="POST")
async def openai_create_vector_store(
self,
name: str | None = None,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
@@ -193,8 +210,8 @@ class VectorIO(Protocol):
@webmethod(route="/openai/v1/vector_stores", method="GET")
async def openai_list_vector_stores(
self,
limit: int = 20,
order: str = "desc",
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
@@ -256,10 +273,10 @@ class VectorIO(Protocol):
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int = 10,
max_num_results: int | None = 10,
ranking_options: dict[str, Any] | None = None,
rewrite_query: bool = False,
) -> VectorStoreSearchResponse:
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
"""Search for chunks in a vector store.
Searches a vector store for relevant chunks based on a query and optional file attribute filters.
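
For orientation, the reshaped search response can be read as a small, self-contained sketch (pydantic v2 assumed; the @json_schema_type decorators are omitted and the field values are illustrative): a page wraps per-chunk results, and each result carries typed content.

from typing import Literal
from pydantic import BaseModel

class VectorStoreContent(BaseModel):
    type: Literal["text"]
    text: str

class VectorStoreSearchResponse(BaseModel):
    file_id: str
    filename: str
    score: float
    attributes: dict[str, str | float | bool] | None = None
    content: list[VectorStoreContent]

class VectorStoreSearchResponsePage(BaseModel):
    object: str = "vector_store.search_results.page"
    search_query: str
    data: list[VectorStoreSearchResponse]
    has_more: bool = False
    next_page: str | None = None

# Hypothetical values, showing the nesting only.
page = VectorStoreSearchResponsePage(
    search_query="maintenance schedule",
    data=[
        VectorStoreSearchResponse(
            file_id="file-123",
            filename="manual.pdf",
            score=0.87,
            content=[VectorStoreContent(type="text", text="Service every 5k miles.")],
        )
    ],
)
print(page.model_dump_json(indent=2))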

View file

@@ -180,6 +180,7 @@ def get_provider_registry(
if provider_type_key in ret[api]:
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
ret[api][provider_type_key] = spec
logger.info(f"Successfully loaded external provider {provider_type_key}")
except yaml.YAMLError as yaml_err:
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
raise yaml_err

View file

@@ -394,9 +394,13 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
logger.error(f"Method {name} incompatible proto: {proto_params} vs. obj: {obj_params}")
missing_methods.append((name, "signature_mismatch"))
else:
# Check if the method is actually implemented in the class
method_owner = next((cls for cls in mro if name in cls.__dict__), None)
if method_owner is None or method_owner.__name__ == protocol.__name__:
# Check if the method has a concrete implementation (not just a protocol stub)
# Find all classes in MRO that define this method
method_owners = [cls for cls in mro if name in cls.__dict__]
# Allow methods from mixins/parents, only reject if ONLY the protocol defines it
if len(method_owners) == 1 and method_owners[0].__name__ == protocol.__name__:
# Only reject if the method is ONLY defined in the protocol itself (abstract stub)
missing_methods.append((name, "not_actually_implemented"))
if missing_methods:
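
The effect of the relaxed compliance check can be demonstrated in isolation: a method inherited from a mixin or parent class now counts as implemented, and a name is rejected only when the protocol stub is its sole definer. A toy sketch with hypothetical names:

from typing import Protocol

class Greeter(Protocol):
    def greet(self) -> str: ...

class GreetMixin:
    def greet(self) -> str:
        return "hello"

class ViaMixin(GreetMixin, Greeter):
    pass

class StubOnly(Greeter):
    pass

def is_implemented(cls: type, name: str, protocol: type) -> bool:
    # Mirrors the diff: collect every class in the MRO that defines `name`,
    # and reject only when the protocol itself is the sole definer.
    owners = [c for c in cls.__mro__ if name in c.__dict__]
    return not (len(owners) == 1 and owners[0].__name__ == protocol.__name__)

print(is_implemented(ViaMixin, "greet", Greeter))   # True: the mixin provides it
print(is_implemented(StubOnly, "greet", Greeter))   # False: only the protocol stub exists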

View file

@@ -163,6 +163,9 @@ class InferenceRouter(Inference):
messages: list[Message] | InterleavedContent,
tool_prompt_format: ToolPromptFormat | None = None,
) -> int | None:
if not hasattr(self, "formatter") or self.formatter is None:
return None
if isinstance(messages, list):
encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format)
else:

View file

@@ -17,7 +17,7 @@ from llama_stack.apis.vector_io import (
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponse,
VectorStoreSearchResponsePage,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
@@ -108,7 +108,7 @@ class VectorIORouter(VectorIO):
# OpenAI Vector Stores API endpoints
async def openai_create_vector_store(
self,
name: str | None = None,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
@@ -151,8 +151,8 @@ class VectorIORouter(VectorIO):
async def openai_list_vector_stores(
self,
limit: int = 20,
order: str = "desc",
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
@@ -239,10 +239,10 @@ class VectorIORouter(VectorIO):
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int = 10,
max_num_results: int | None = 10,
ranking_options: dict[str, Any] | None = None,
rewrite_query: bool = False,
) -> VectorStoreSearchResponse:
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)

View file

@@ -84,6 +84,7 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
class OAuth2JWKSConfig(BaseModel):
# The JWKS URI for collecting public keys
uri: str
token: str | None = Field(default=None, description="token to authorise access to jwks")
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
@@ -246,9 +247,12 @@ class OAuth2TokenAuthProvider(AuthProvider):
if self.config.jwks is None:
raise ValueError("JWKS is not configured")
if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
headers = {}
if self.config.jwks.token:
headers["Authorization"] = f"Bearer {self.config.jwks.token}"
verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
async with httpx.AsyncClient(verify=verify) as client:
res = await client.get(self.config.jwks.uri, timeout=5)
res = await client.get(self.config.jwks.uri, timeout=5, headers=headers)
res.raise_for_status()
jwks_data = res.json()["keys"]
updated = {}
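
Read on its own, the new fetch path attaches a bearer header only when a JWKS token is configured, so unauthenticated endpoints keep working. A minimal sketch of that pattern (the endpoint URL and token are placeholders):

import httpx

async def fetch_jwks(uri: str, token: str | None = None) -> list[dict]:
    # Mirrors the diff: add Authorization only when a token is set.
    headers = {}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    async with httpx.AsyncClient() as client:
        res = await client.get(uri, timeout=5, headers=headers)
        res.raise_for_status()
        return res.json()["keys"]

# e.g. asyncio.run(fetch_jwks("https://issuer.example/jwks", token="..."))  # hypothetical endpoint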

View file

@@ -115,7 +115,7 @@ def parse_environment_config(env_config: str) -> dict[str, int]:
class CustomRichHandler(RichHandler):
def __init__(self, *args, **kwargs):
kwargs["console"] = Console(width=120)
kwargs["console"] = Console(width=150)
super().__init__(*args, **kwargs)
def emit(self, record):

View file

@@ -9,9 +9,7 @@ import base64
import io
import json
import logging
import time
import uuid
from typing import Any, Literal
from typing import Any
import faiss
import numpy as np
@@ -24,14 +22,11 @@ from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponse,
)
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
@@ -47,10 +42,6 @@ FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
# In faiss, since we do
CHUNK_MULTIPLIER = 5
class FaissIndex(EmbeddingIndex):
def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
self.index = faiss.IndexFlatL2(dimension)
@@ -140,7 +131,7 @@ class FaissIndex(EmbeddingIndex):
raise NotImplementedError("Keyword search is not supported in FAISS")
class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None:
self.config = config
self.inference_api = inference_api
@@ -164,14 +155,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
)
self.cache[vector_db.identifier] = index
# Load existing OpenAI vector stores
start_key = OPENAI_VECTOR_STORES_PREFIX
end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
for store_data in stored_openai_stores:
store_info = json.loads(store_data)
self.openai_vector_stores[store_info["id"]] = store_info
# Load existing OpenAI vector stores using the mixin method
self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None:
# Cleanup if needed
@@ -234,285 +219,34 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
return await index.query_chunks(query, params)
# OpenAI Vector Stores API endpoints implementation
async def openai_create_vector_store(
self,
name: str | None = None,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
"""Creates a vector store."""
# OpenAI Vector Store Mixin abstract method implementations
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to kvstore."""
assert self.kvstore is not None
# store and vector_db have the same id
store_id = name or str(uuid.uuid4())
created_at = int(time.time())
if provider_id is None:
raise ValueError("Provider ID is required")
if embedding_model is None:
raise ValueError("Embedding model is required")
# Use provided embedding dimension or default to 384
if embedding_dimension is None:
raise ValueError("Embedding dimension is required")
provider_vector_db_id = provider_vector_db_id or store_id
vector_db = VectorDB(
identifier=store_id,
embedding_dimension=embedding_dimension,
embedding_model=embedding_model,
provider_id=provider_id,
provider_resource_id=provider_vector_db_id,
)
# Register the vector DB
await self.register_vector_db(vector_db)
# Create OpenAI vector store metadata
store_info = {
"id": store_id,
"object": "vector_store",
"created_at": created_at,
"name": store_id,
"usage_bytes": 0,
"file_counts": {},
"status": "completed",
"expires_after": expires_after,
"expires_at": None,
"last_active_at": created_at,
"file_ids": file_ids or [],
"chunking_strategy": chunking_strategy,
}
# Add provider information to metadata if provided
metadata = metadata or {}
if provider_id:
metadata["provider_id"] = provider_id
if provider_vector_db_id:
metadata["provider_vector_db_id"] = provider_vector_db_id
store_info["metadata"] = metadata
# Store in kvstore
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
# Store in memory cache
self.openai_vector_stores[store_id] = store_info
return VectorStoreObject(
id=store_id,
created_at=created_at,
name=store_id,
usage_bytes=0,
file_counts={},
status="completed",
expires_after=expires_after,
expires_at=None,
last_active_at=created_at,
metadata=metadata,
)
async def openai_list_vector_stores(
self,
limit: int = 20,
order: str = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
"""Returns a list of vector stores."""
# Get all vector stores
all_stores = list(self.openai_vector_stores.values())
# Sort by created_at
reverse_order = order == "desc"
all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order)
# Apply cursor-based pagination
if after:
after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
if after_index >= 0:
all_stores = all_stores[after_index + 1 :]
if before:
before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
all_stores = all_stores[:before_index]
# Apply limit
limited_stores = all_stores[:limit]
# Convert to VectorStoreObject instances
data = [VectorStoreObject(**store) for store in limited_stores]
# Determine pagination info
has_more = len(all_stores) > limit
first_id = data[0].id if data else None
last_id = data[-1].id if data else None
return VectorStoreListResponse(
data=data,
has_more=has_more,
first_id=first_id,
last_id=last_id,
)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
"""Retrieves a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id]
return VectorStoreObject(**store_info)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
"""Modifies a vector store."""
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from kvstore."""
assert self.kvstore is not None
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
start_key = OPENAI_VECTOR_STORES_PREFIX
end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
store_info = self.openai_vector_stores[vector_store_id].copy()
stores = {}
for store_data in stored_openai_stores:
store_info = json.loads(store_data)
stores[store_info["id"]] = store_info
return stores
# Update fields if provided
if name is not None:
store_info["name"] = name
if expires_after is not None:
store_info["expires_after"] = expires_after
if metadata is not None:
store_info["metadata"] = metadata
# Update last_active_at
store_info["last_active_at"] = int(time.time())
# Save to kvstore
key = f"{OPENAI_VECTOR_STORES_PREFIX}{vector_store_id}"
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in kvstore."""
assert self.kvstore is not None
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
# Update in-memory cache
self.openai_vector_stores[vector_store_id] = store_info
return VectorStoreObject(**store_info)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
"""Delete a vector store."""
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from kvstore."""
assert self.kvstore is not None
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
# Delete from kvstore
key = f"{OPENAI_VECTOR_STORES_PREFIX}{vector_store_id}"
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.delete(key)
# Delete from in-memory cache
del self.openai_vector_stores[vector_store_id]
# Also delete the underlying vector DB
try:
await self.unregister_vector_db(vector_store_id)
except Exception as e:
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
return VectorStoreDeleteResponse(
id=vector_store_id,
deleted=True,
)
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int = 10,
ranking_options: dict[str, Any] | None = None,
rewrite_query: bool = False,
search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
) -> VectorStoreSearchResponse:
"""Search for chunks in a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
if isinstance(query, list):
search_query = " ".join(query)
else:
search_query = query
try:
score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
params = {
"max_chunks": max_num_results * CHUNK_MULTIPLIER,
"score_threshold": score_threshold,
"mode": search_mode,
}
# TODO: Add support for ranking_options.ranker
response = await self.query_chunks(
vector_db_id=vector_store_id,
query=search_query,
params=params,
)
# Convert response to OpenAI format
data = []
for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)):
# Apply score based filtering
if score < score_threshold:
continue
# Apply filters if provided
if filters:
# Simple metadata filtering
if not self._matches_filters(chunk.metadata, filters):
continue
chunk_data = {
"id": f"chunk_{i}",
"object": "vector_store.search_result",
"score": score,
"content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content),
"metadata": chunk.metadata,
}
data.append(chunk_data)
if len(data) >= max_num_results:
break
return VectorStoreSearchResponse(
search_query=search_query,
data=data,
has_more=False, # For simplicity, we don't implement pagination here
next_page=None,
)
except Exception as e:
logger.error(f"Error searching vector store {vector_store_id}: {e}")
# Return empty results on error
return VectorStoreSearchResponse(
search_query=search_query,
data=[],
has_more=False,
next_page=None,
)
def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
"""Check if metadata matches the provided filters."""
for key, value in filters.items():
if key not in metadata:
return False
if metadata[key] != value:
return False
return True
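
The `_load_openai_vector_stores` implementation above relies on a lexicographic prefix scan: appending `\xff` to the prefix yields an upper bound for every key that shares it. A standalone sketch against a hypothetical dict-backed kvstore:

import json

OPENAI_VECTOR_STORES_PREFIX = "openai_vector_stores:v3::"

# Hypothetical stand-in for the KVStore backend used in the diff.
kv = {
    f"{OPENAI_VECTOR_STORES_PREFIX}vs_1": json.dumps({"id": "vs_1", "name": "docs"}),
    f"{OPENAI_VECTOR_STORES_PREFIX}vs_2": json.dumps({"id": "vs_2", "name": "code"}),
    "unrelated:key": json.dumps({"id": "nope"}),
}

def values_in_range(start_key: str, end_key: str) -> list[str]:
    # A real backend would do this scan server-side, in key order.
    return [v for k, v in sorted(kv.items()) if start_key <= k <= end_key]

start_key = OPENAI_VECTOR_STORES_PREFIX
end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"  # upper bound for the shared prefix
stores = {}
for raw in values_in_range(start_key, end_key):
    info = json.loads(raw)
    stores[info["id"]] = info
print(sorted(stores))  # ['vs_1', 'vs_2']; the unrelated key is excluded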

View file

@@ -10,9 +10,8 @@ import json
import logging
import sqlite3
import struct
import time
import uuid
from typing import Any, Literal
from typing import Any
import numpy as np
import sqlite_vec
@@ -24,12 +23,9 @@ from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponse,
)
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
logger = logging.getLogger(__name__)
@@ -39,11 +35,6 @@ VECTOR_SEARCH = "vector"
KEYWORD_SEARCH = "keyword"
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
# Constants for OpenAI vector stores (similar to faiss)
VERSION = "v3"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
CHUNK_MULTIPLIER = 5
def serialize_vector(vector: list[float]) -> bytes:
"""Serialize a list of floats into a compact binary representation."""
@@ -303,7 +294,7 @@ class SQLiteVecIndex(EmbeddingIndex):
return QueryChunksResponse(chunks=chunks, scores=scores)
class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
"""
A VectorIO implementation using SQLite + sqlite_vec.
This class handles vector database registration (with metadata stored in a table named `vector_dbs`)
@@ -340,15 +331,12 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
# Load any existing vector DB registrations.
cur.execute("SELECT metadata FROM vector_dbs")
vector_db_rows = cur.fetchall()
# Load any existing OpenAI vector stores.
cur.execute("SELECT metadata FROM openai_vector_stores")
openai_store_rows = cur.fetchall()
return vector_db_rows, openai_store_rows
return vector_db_rows
finally:
cur.close()
connection.close()
vector_db_rows, openai_store_rows = await asyncio.to_thread(_setup_connection)
vector_db_rows = await asyncio.to_thread(_setup_connection)
# Load existing vector DBs
for row in vector_db_rows:
@@ -359,11 +347,8 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
# Load existing OpenAI vector stores
for row in openai_store_rows:
store_data = row[0]
store_info = json.loads(store_data)
self.openai_vector_stores[store_info["id"]] = store_info
# Load existing OpenAI vector stores using the mixin method
self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None:
# nothing to do since we don't maintain a persistent connection
@@ -409,6 +394,87 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
await asyncio.to_thread(_delete_vector_db_from_registry)
# OpenAI Vector Store Mixin abstract method implementations
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to SQLite database."""
def _store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)",
(store_id, json.dumps(store_info)),
)
connection.commit()
except Exception as e:
logger.error(f"Error saving openai vector store {store_id}: {e}")
raise
finally:
cur.close()
connection.close()
try:
await asyncio.to_thread(_store)
except Exception as e:
logger.error(f"Error saving openai vector store {store_id}: {e}")
raise
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from SQLite database."""
def _load():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("SELECT metadata FROM openai_vector_stores")
rows = cur.fetchall()
return rows
finally:
cur.close()
connection.close()
rows = await asyncio.to_thread(_load)
stores = {}
for row in rows:
store_data = row[0]
store_info = json.loads(store_data)
stores[store_info["id"]] = store_info
return stores
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in SQLite database."""
def _update():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"UPDATE openai_vector_stores SET metadata = ? WHERE id = ?",
(json.dumps(store_info), store_id),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_update)
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from SQLite database."""
def _delete():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (store_id,))
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete)
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
if vector_db_id not in self.cache:
raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}")
@@ -423,318 +489,6 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
raise ValueError(f"Vector DB {vector_db_id} not found")
return await self.cache[vector_db_id].query_chunks(query, params)
async def openai_create_vector_store(
self,
name: str | None = None,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
"""Creates a vector store."""
# store and vector_db have the same id
store_id = name or str(uuid.uuid4())
created_at = int(time.time())
if provider_id is None:
raise ValueError("Provider ID is required")
if embedding_model is None:
raise ValueError("Embedding model is required")
# Use provided embedding dimension or default to 384
if embedding_dimension is None:
raise ValueError("Embedding dimension is required")
provider_vector_db_id = provider_vector_db_id or store_id
vector_db = VectorDB(
identifier=store_id,
embedding_dimension=embedding_dimension,
embedding_model=embedding_model,
provider_id=provider_id,
provider_resource_id=provider_vector_db_id,
)
# Register the vector DB
await self.register_vector_db(vector_db)
# Create OpenAI vector store metadata
store_info = {
"id": store_id,
"object": "vector_store",
"created_at": created_at,
"name": store_id,
"usage_bytes": 0,
"file_counts": {},
"status": "completed",
"expires_after": expires_after,
"expires_at": None,
"last_active_at": created_at,
"file_ids": file_ids or [],
"chunking_strategy": chunking_strategy,
}
# Add provider information to metadata if provided
metadata = metadata or {}
if provider_id:
metadata["provider_id"] = provider_id
if provider_vector_db_id:
metadata["provider_vector_db_id"] = provider_vector_db_id
store_info["metadata"] = metadata
# Store in SQLite database
def _store_openai_vector_store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)",
(store_id, json.dumps(store_info)),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_store_openai_vector_store)
# Store in memory cache
self.openai_vector_stores[store_id] = store_info
return VectorStoreObject(
id=store_id,
created_at=created_at,
name=store_id,
usage_bytes=0,
file_counts={},
status="completed",
expires_after=expires_after,
expires_at=None,
last_active_at=created_at,
metadata=metadata,
)
async def openai_list_vector_stores(
self,
limit: int = 20,
order: str = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
"""Returns a list of vector stores."""
# Get all vector stores
all_stores = list(self.openai_vector_stores.values())
# Sort by created_at
reverse_order = order == "desc"
all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order)
# Apply cursor-based pagination
if after:
after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
if after_index >= 0:
all_stores = all_stores[after_index + 1 :]
if before:
before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
all_stores = all_stores[:before_index]
# Apply limit
limited_stores = all_stores[:limit]
# Convert to VectorStoreObject instances
data = [VectorStoreObject(**store) for store in limited_stores]
# Determine pagination info
has_more = len(all_stores) > limit
first_id = data[0].id if data else None
last_id = data[-1].id if data else None
return VectorStoreListResponse(
data=data,
has_more=has_more,
first_id=first_id,
last_id=last_id,
)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
"""Retrieves a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id]
return VectorStoreObject(**store_info)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
"""Modifies a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id].copy()
# Update fields if provided
if name is not None:
store_info["name"] = name
if expires_after is not None:
store_info["expires_after"] = expires_after
if metadata is not None:
store_info["metadata"] = metadata
# Update last_active_at
store_info["last_active_at"] = int(time.time())
# Save to SQLite database
def _update_openai_vector_store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"UPDATE openai_vector_stores SET metadata = ? WHERE id = ?",
(json.dumps(store_info), vector_store_id),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_update_openai_vector_store)
# Update in-memory cache
self.openai_vector_stores[vector_store_id] = store_info
return VectorStoreObject(**store_info)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
"""Delete a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
# Delete from SQLite database
def _delete_openai_vector_store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (vector_store_id,))
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete_openai_vector_store)
# Delete from in-memory cache
del self.openai_vector_stores[vector_store_id]
# Also delete the underlying vector DB
try:
await self.unregister_vector_db(vector_store_id)
except Exception as e:
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
return VectorStoreDeleteResponse(
id=vector_store_id,
deleted=True,
)
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int = 10,
ranking_options: dict[str, Any] | None = None,
rewrite_query: bool = False,
search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
) -> VectorStoreSearchResponse:
"""Search for chunks in a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
if isinstance(query, list):
search_query = " ".join(query)
else:
search_query = query
try:
score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
params = {
"max_chunks": max_num_results * CHUNK_MULTIPLIER,
"score_threshold": score_threshold,
"mode": search_mode,
}
# TODO: Add support for ranking_options.ranker
response = await self.query_chunks(
vector_db_id=vector_store_id,
query=search_query,
params=params,
)
# Convert response to OpenAI format
data = []
for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)):
# Apply score based filtering
if score < score_threshold:
continue
# Apply filters if provided
if filters:
# Simple metadata filtering
if not self._matches_filters(chunk.metadata, filters):
continue
chunk_data = {
"id": f"chunk_{i}",
"object": "vector_store.search_result",
"score": score,
"content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content),
"metadata": chunk.metadata,
}
data.append(chunk_data)
if len(data) >= max_num_results:
break
return VectorStoreSearchResponse(
search_query=search_query,
data=data,
has_more=False, # For simplicity, we don't implement pagination here
next_page=None,
)
except Exception as e:
logger.error(f"Error searching vector store {vector_store_id}: {e}")
# Return empty results on error
return VectorStoreSearchResponse(
search_query=search_query,
data=[],
has_more=False,
next_page=None,
)
def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
"""Check if metadata matches the provided filters."""
for key, value in filters.items():
if key not in metadata:
return False
if metadata[key] != value:
return False
return True
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
"""Generate a unique chunk ID using a hash of document ID and chunk text."""

View file

@@ -6,7 +6,7 @@
import asyncio
import json
import logging
from typing import Any, Literal
from typing import Any
from urllib.parse import urlparse
import chromadb
@@ -21,7 +21,7 @@ from llama_stack.apis.vector_io import (
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponse,
VectorStoreSearchResponsePage,
)
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
@@ -189,7 +189,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
async def openai_create_vector_store(
self,
name: str | None = None,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
@@ -203,8 +203,8 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
async def openai_list_vector_stores(
self,
limit: int = 20,
order: str = "desc",
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
@@ -236,9 +236,8 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int = 10,
max_num_results: int | None = 10,
ranking_options: dict[str, Any] | None = None,
rewrite_query: bool = False,
search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
) -> VectorStoreSearchResponse:
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")

View file

@@ -9,7 +9,7 @@ import hashlib
import logging
import os
import uuid
from typing import Any, Literal
from typing import Any
from numpy.typing import NDArray
from pymilvus import MilvusClient
@@ -23,7 +23,7 @@ from llama_stack.apis.vector_io import (
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponse,
VectorStoreSearchResponsePage,
)
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
@@ -187,7 +187,7 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
async def openai_create_vector_store(
self,
name: str | None = None,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
@@ -201,8 +201,8 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
async def openai_list_vector_stores(
self,
limit: int = 20,
order: str = "desc",
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
@@ -234,11 +234,10 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int = 10,
max_num_results: int | None = 10,
ranking_options: dict[str, Any] | None = None,
rewrite_query: bool = False,
search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
) -> VectorStoreSearchResponse:
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

View file

@@ -6,7 +6,7 @@
import logging
import uuid
from typing import Any, Literal
from typing import Any
from numpy.typing import NDArray
from qdrant_client import AsyncQdrantClient, models
@@ -21,7 +21,7 @@ from llama_stack.apis.vector_io import (
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponse,
VectorStoreSearchResponsePage,
)
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
@@ -189,7 +189,7 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
async def openai_create_vector_store(
self,
name: str | None = None,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
@@ -203,8 +203,8 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
async def openai_list_vector_stores(
self,
limit: int = 20,
order: str = "desc",
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
@@ -236,9 +236,8 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int = 10,
max_num_results: int | None = 10,
ranking_options: dict[str, Any] | None = None,
rewrite_query: bool = False,
search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
) -> VectorStoreSearchResponse:
rewrite_query: bool | None = False,
) -> VectorStoreSearchResponsePage:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")

View file

@@ -76,7 +76,7 @@ class WeaviateIndex(EmbeddingIndex):
continue
chunks.append(chunk)
scores.append(1.0 / doc.metadata.distance)
scores.append(1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf"))
return QueryChunksResponse(chunks=chunks, scores=scores)
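
The Weaviate change guards the inverse-distance conversion against a zero distance (an exact match), which previously raised ZeroDivisionError. The conversion in isolation:

def distance_to_score(distance: float) -> float:
    # Exact matches report distance 0; map them to +inf instead of dividing by zero.
    return 1.0 / distance if distance != 0 else float("inf")

print(distance_to_score(0.25))  # 4.0
print(distance_to_score(0.0))   # inf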

View file

@@ -0,0 +1,385 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import time
import uuid
from abc import ABC, abstractmethod
from typing import Any
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorStoreContent,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponse,
VectorStoreSearchResponsePage,
)
logger = logging.getLogger(__name__)
# Constants for OpenAI vector stores
CHUNK_MULTIPLIER = 5
class OpenAIVectorStoreMixin(ABC):
"""
Mixin class that provides common OpenAI Vector Store API implementation.
Providers need to implement the abstract storage methods and maintain
an openai_vector_stores in-memory cache.
"""
# These should be provided by the implementing class
openai_vector_stores: dict[str, dict[str, Any]]
@abstractmethod
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to persistent storage."""
pass
@abstractmethod
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from persistent storage."""
pass
@abstractmethod
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in persistent storage."""
pass
@abstractmethod
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from persistent storage."""
pass
@abstractmethod
async def register_vector_db(self, vector_db: VectorDB) -> None:
"""Register a vector database (provider-specific implementation)."""
pass
@abstractmethod
async def unregister_vector_db(self, vector_db_id: str) -> None:
"""Unregister a vector database (provider-specific implementation)."""
pass
@abstractmethod
async def query_chunks(
self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
) -> QueryChunksResponse:
"""Query chunks from a vector database (provider-specific implementation)."""
pass
async def openai_create_vector_store(
self,
name: str,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
"""Creates a vector store."""
# store and vector_db have the same id
store_id = name or str(uuid.uuid4())
created_at = int(time.time())
if provider_id is None:
raise ValueError("Provider ID is required")
if embedding_model is None:
raise ValueError("Embedding model is required")
# Use provided embedding dimension or default to 384
if embedding_dimension is None:
raise ValueError("Embedding dimension is required")
provider_vector_db_id = provider_vector_db_id or store_id
vector_db = VectorDB(
identifier=store_id,
embedding_dimension=embedding_dimension,
embedding_model=embedding_model,
provider_id=provider_id,
provider_resource_id=provider_vector_db_id,
)
# Register the vector DB
await self.register_vector_db(vector_db)
# Create OpenAI vector store metadata
store_info = {
"id": store_id,
"object": "vector_store",
"created_at": created_at,
"name": store_id,
"usage_bytes": 0,
"file_counts": {},
"status": "completed",
"expires_after": expires_after,
"expires_at": None,
"last_active_at": created_at,
"file_ids": file_ids or [],
"chunking_strategy": chunking_strategy,
}
# Add provider information to metadata if provided
metadata = metadata or {}
if provider_id:
metadata["provider_id"] = provider_id
if provider_vector_db_id:
metadata["provider_vector_db_id"] = provider_vector_db_id
store_info["metadata"] = metadata
# Save to persistent storage (provider-specific)
await self._save_openai_vector_store(store_id, store_info)
# Store in memory cache
self.openai_vector_stores[store_id] = store_info
return VectorStoreObject(
id=store_id,
created_at=created_at,
name=store_id,
usage_bytes=0,
file_counts={},
status="completed",
expires_after=expires_after,
expires_at=None,
last_active_at=created_at,
metadata=metadata,
)
async def openai_list_vector_stores(
self,
limit: int | None = 20,
order: str | None = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
"""Returns a list of vector stores."""
limit = limit or 20
order = order or "desc"
# Get all vector stores
all_stores = list(self.openai_vector_stores.values())
# Sort by created_at
reverse_order = order == "desc"
all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order)
# Apply cursor-based pagination
if after:
after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
if after_index >= 0:
all_stores = all_stores[after_index + 1 :]
if before:
before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
all_stores = all_stores[:before_index]
# Apply limit
limited_stores = all_stores[:limit]
# Convert to VectorStoreObject instances
data = [VectorStoreObject(**store) for store in limited_stores]
# Determine pagination info
has_more = len(all_stores) > limit
first_id = data[0].id if data else None
last_id = data[-1].id if data else None
return VectorStoreListResponse(
data=data,
has_more=has_more,
first_id=first_id,
last_id=last_id,
)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
"""Retrieves a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id]
return VectorStoreObject(**store_info)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
"""Modifies a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id].copy()
# Update fields if provided
if name is not None:
store_info["name"] = name
if expires_after is not None:
store_info["expires_after"] = expires_after
if metadata is not None:
store_info["metadata"] = metadata
# Update last_active_at
store_info["last_active_at"] = int(time.time())
# Save to persistent storage (provider-specific)
await self._update_openai_vector_store(vector_store_id, store_info)
# Update in-memory cache
self.openai_vector_stores[vector_store_id] = store_info
return VectorStoreObject(**store_info)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
"""Delete a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
# Delete from persistent storage (provider-specific)
await self._delete_openai_vector_store_from_storage(vector_store_id)
# Delete from in-memory cache
del self.openai_vector_stores[vector_store_id]
# Also delete the underlying vector DB
try:
await self.unregister_vector_db(vector_store_id)
except Exception as e:
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
return VectorStoreDeleteResponse(
id=vector_store_id,
deleted=True,
)
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int | None = 10,
ranking_options: dict[str, Any] | None = None,
rewrite_query: bool | None = False,
# search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
) -> VectorStoreSearchResponsePage:
"""Search for chunks in a vector store."""
# TODO: Add support in the API for this
search_mode = "vector"
max_num_results = max_num_results or 10
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
if isinstance(query, list):
search_query = " ".join(query)
else:
search_query = query
try:
score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
params = {
"max_chunks": max_num_results * CHUNK_MULTIPLIER,
"score_threshold": score_threshold,
"mode": search_mode,
}
# TODO: Add support for ranking_options.ranker
response = await self.query_chunks(
vector_db_id=vector_store_id,
query=search_query,
params=params,
)
# Convert response to OpenAI format
data = []
for chunk, score in zip(response.chunks, response.scores, strict=False):
# Apply score based filtering
if score < score_threshold:
continue
# Apply filters if provided
if filters:
# Simple metadata filtering
if not self._matches_filters(chunk.metadata, filters):
continue
# content is InterleavedContent
if isinstance(chunk.content, str):
content = [
VectorStoreContent(
type="text",
text=chunk.content,
)
]
elif isinstance(chunk.content, list):
# TODO: Add support for other types of content
content = [
VectorStoreContent(
type="text",
text=item.text,
)
for item in chunk.content
if item.type == "text"
]
else:
if chunk.content.type != "text":
raise ValueError(f"Unsupported content type: {chunk.content.type}")
content = [
VectorStoreContent(
type="text",
text=chunk.content.text,
)
]
response_data_item = VectorStoreSearchResponse(
file_id=chunk.metadata.get("file_id", ""),
filename=chunk.metadata.get("filename", ""),
score=score,
attributes=chunk.metadata,
content=content,
)
data.append(response_data_item)
if len(data) >= max_num_results:
break
return VectorStoreSearchResponsePage(
search_query=search_query,
data=data,
has_more=False, # For simplicity, we don't implement pagination here
next_page=None,
)
except Exception as e:
logger.error(f"Error searching vector store {vector_store_id}: {e}")
# Return empty results on error
return VectorStoreSearchResponsePage(
search_query=search_query,
data=[],
has_more=False,
next_page=None,
)
def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
"""Check if metadata matches the provided filters."""
for key, value in filters.items():
if key not in metadata:
return False
if metadata[key] != value:
return False
return True
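
To make the mixin contract concrete, here is a hedged sketch of the four storage hooks backed by a plain dict. A real adapter mixes them into a class such as FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate), as the faiss and sqlite_vec diffs above do, and must also provide register_vector_db, unregister_vector_db, and query_chunks. All names below are illustrative.

import asyncio
from typing import Any

class InMemoryVectorStoreBackend:
    # Hypothetical backend: persistence is a dict instead of kvstore/SQLite.
    def __init__(self) -> None:
        self._storage: dict[str, dict[str, Any]] = {}
        self.openai_vector_stores: dict[str, dict[str, Any]] = {}

    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
        self._storage[store_id] = store_info

    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
        return dict(self._storage)

    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
        self._storage[store_id] = store_info

    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
        self._storage.pop(store_id, None)

async def _demo() -> None:
    backend = InMemoryVectorStoreBackend()
    await backend._save_openai_vector_store("vs_1", {"id": "vs_1", "name": "docs"})
    backend.openai_vector_stores = await backend._load_openai_vector_stores()
    print(backend.openai_vector_stores)

asyncio.run(_demo())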