feat: Add OpenAI compat /v1/vector_store APIs (#2423)

Adds OpenAI-compatible `/v1/vector_stores` APIs.
This PR implements the `faiss` provider; follow-up PRs will cover the
other providers.

Added routes to create, update, delete, and list vector stores, plus a
route to search a vector store.

Inserting files into vector stores is not yet supported; that will come in
a follow-up PR. A usage sketch of the new routes follows below.
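
Purely as illustration (not part of this diff), a minimal sketch of
exercising the new routes over HTTP with `requests`; the base URL and
embedding model name are assumptions about a local dev setup:

```
import requests

BASE = "http://localhost:8321/openai/v1"  # assumed local Llama Stack server

# Create a store. If embedding_model is omitted, the router falls back to
# the first registered embedding model it finds.
store = requests.post(
    f"{BASE}/vector_stores",
    json={"name": "my-docs", "embedding_model": "all-MiniLM-L6-v2"},
).json()
print(store["id"], store["status"])  # the store id mirrors the name here

# List stores (sorted by created_at, newest first by default).
stores = requests.get(f"{BASE}/vector_stores", params={"limit": 20}).json()
print(stores["has_more"], [s["id"] for s in stores["data"]])
```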

### Test Plan 
- Added a new integration test exercising the faiss provider:
```
pytest -sv --stack-config http://localhost:8321 tests/integration/vector_io/test_openai_vector_stores.py --embedding-model all-MiniLM-L6-v2
```
Hardik Shah committed on 2025-06-10 13:07:39 -07:00 (via GitHub)
parent ee57e58f29
commit 5ac43268e8
10 changed files with 2930 additions and 16 deletions


@@ -37,6 +37,85 @@ class QueryChunksResponse(BaseModel):
scores: list[float]
@json_schema_type
class VectorStoreObject(BaseModel):
"""OpenAI Vector Store object."""
id: str
object: str = "vector_store"
created_at: int
name: str | None = None
usage_bytes: int = 0
file_counts: dict[str, int] = Field(default_factory=dict)
status: str = "completed"
expires_after: dict[str, Any] | None = None
expires_at: int | None = None
last_active_at: int | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class VectorStoreCreateRequest(BaseModel):
"""Request to create a vector store."""
name: str | None = None
file_ids: list[str] = Field(default_factory=list)
expires_after: dict[str, Any] | None = None
chunking_strategy: dict[str, Any] | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
@json_schema_type
class VectorStoreModifyRequest(BaseModel):
"""Request to modify a vector store."""
name: str | None = None
expires_after: dict[str, Any] | None = None
metadata: dict[str, Any] | None = None
@json_schema_type
class VectorStoreListResponse(BaseModel):
"""Response from listing vector stores."""
object: str = "list"
data: list[VectorStoreObject]
first_id: str | None = None
last_id: str | None = None
has_more: bool = False
@json_schema_type
class VectorStoreSearchRequest(BaseModel):
"""Request to search a vector store."""
query: str | list[str]
filters: dict[str, Any] | None = None
max_num_results: int = 10
ranking_options: dict[str, Any] | None = None
rewrite_query: bool = False
@json_schema_type
class VectorStoreSearchResponse(BaseModel):
"""Response from searching a vector store."""
object: str = "vector_store.search_results.page"
search_query: str
data: list[dict[str, Any]]
has_more: bool = False
next_page: str | None = None
@json_schema_type
class VectorStoreDeleteResponse(BaseModel):
"""Response from deleting a vector store."""
id: str
object: str = "vector_store.deleted"
deleted: bool = True
class VectorDBStore(Protocol):
def get_vector_db(self, vector_db_id: str) -> VectorDB | None: ...
@@ -81,3 +160,116 @@ class VectorIO(Protocol):
:returns: A QueryChunksResponse.
"""
...
# OpenAI Vector Stores API endpoints
@webmethod(route="/openai/v1/vector_stores", method="POST")
async def openai_create_vector_store(
self,
name: str | None = None,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
"""Creates a vector store.
:param name: A name for the vector store.
:param file_ids: A list of File IDs that the vector store should use. Useful for tools like `file_search` that can access files.
:param expires_after: The expiration policy for a vector store.
:param chunking_strategy: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy.
:param metadata: Set of 16 key-value pairs that can be attached to an object.
:param embedding_model: The embedding model to use for this vector store.
:param embedding_dimension: The dimension of the embedding vectors (default: 384).
:param provider_id: The ID of the provider to use for this vector store.
:param provider_vector_db_id: The provider-specific vector database ID.
:returns: A VectorStoreObject representing the created vector store.
"""
...
@webmethod(route="/openai/v1/vector_stores", method="GET")
async def openai_list_vector_stores(
self,
limit: int = 20,
order: str = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
"""Returns a list of vector stores.
:param limit: A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.
:param order: Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.
:param after: A cursor for use in pagination. `after` is an object ID that defines your place in the list.
:param before: A cursor for use in pagination. `before` is an object ID that defines your place in the list.
:returns: A VectorStoreListResponse containing the list of vector stores.
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}", method="GET")
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
"""Retrieves a vector store.
:param vector_store_id: The ID of the vector store to retrieve.
:returns: A VectorStoreObject representing the vector store.
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}", method="POST")
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
"""Updates a vector store.
:param vector_store_id: The ID of the vector store to update.
:param name: The name of the vector store.
:param expires_after: The expiration policy for a vector store.
:param metadata: Set of 16 key-value pairs that can be attached to an object.
:returns: A VectorStoreObject representing the updated vector store.
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}", method="DELETE")
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
"""Delete a vector store.
:param vector_store_id: The ID of the vector store to delete.
:returns: A VectorStoreDeleteResponse indicating the deletion status.
"""
...
@webmethod(route="/openai/v1/vector_stores/{vector_store_id}/search", method="POST")
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int = 10,
ranking_options: dict[str, Any] | None = None,
rewrite_query: bool = False,
) -> VectorStoreSearchResponse:
"""Search for chunks in a vector store.
Searches a vector store for relevant chunks based on a query and optional file attribute filters.
:param vector_store_id: The ID of the vector store to search.
:param query: The query string or array for performing the search.
:param filters: Filters based on file attributes to narrow the search results.
:param max_num_results: Maximum number of results to return (1 to 50 inclusive, default 10).
:param ranking_options: Ranking options for fine-tuning the search results.
:param rewrite_query: Whether to rewrite the natural language query for vector search (default false).
:returns: A VectorStoreSearchResponse containing the search results.
"""
...
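
As a hedged usage sketch (not part of the diff), a search call against this
route might look like the following; the store ID and metadata key are
hypothetical:

```
import requests

BASE = "http://localhost:8321/openai/v1"  # assumed local Llama Stack server

# Search with an exact-match metadata filter and a score floor.
page = requests.post(
    f"{BASE}/vector_stores/my-docs/search",
    json={
        "query": "What is the capital of France?",
        "filters": {"topic": "geography"},           # hypothetical metadata key
        "max_num_results": 5,
        "ranking_options": {"score_threshold": 0.5},
    },
).json()

for result in page["data"]:  # object: "vector_store.search_results.page"
    print(result["score"], result["content"])
```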


@@ -9,7 +9,16 @@ from typing import Any
from llama_stack.apis.common.content_types import (
InterleavedContent,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.models import ModelType
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponse,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
@@ -34,6 +43,31 @@ class VectorIORouter(VectorIO):
logger.debug("VectorIORouter.shutdown")
pass
async def _get_first_embedding_model(self) -> tuple[str, int] | None:
"""Get the first available embedding model identifier."""
try:
# Get all models from the routing table
all_models = await self.routing_table.get_all_with_type("model")
# Filter for embedding models
embedding_models = [
model
for model in all_models
if hasattr(model, "model_type") and model.model_type == ModelType.embedding
]
if embedding_models:
dimension = embedding_models[0].metadata.get("embedding_dimension", None)
if dimension is None:
raise ValueError(f"Embedding model {embedding_models[0].identifier} has no embedding dimension")
return embedding_models[0].identifier, dimension
else:
logger.warning("No embedding models found in the routing table")
return None
except Exception as e:
logger.error(f"Error getting embedding models: {e}")
return None
async def register_vector_db(
self,
vector_db_id: str,
@@ -70,3 +104,153 @@ class VectorIORouter(VectorIO):
) -> QueryChunksResponse:
logger.debug(f"VectorIORouter.query_chunks: {vector_db_id}")
return await self.routing_table.get_provider_impl(vector_db_id).query_chunks(vector_db_id, query, params)
# OpenAI Vector Stores API endpoints
async def openai_create_vector_store(
self,
name: str | None = None,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = None,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
# If no embedding model is provided, use the first available one
if embedding_model is None:
embedding_model_info = await self._get_first_embedding_model()
if embedding_model_info is None:
raise ValueError("No embedding model provided and no embedding models available in the system")
embedding_model, embedding_dimension = embedding_model_info
logger.info(f"No embedding model specified, using first available: {embedding_model}")
vector_db_id = name
registered_vector_db = await self.routing_table.register_vector_db(
vector_db_id,
embedding_model,
embedding_dimension,
provider_id,
provider_vector_db_id,
)
return await self.routing_table.get_provider_impl(registered_vector_db.identifier).openai_create_vector_store(
vector_db_id,
file_ids=file_ids,
expires_after=expires_after,
chunking_strategy=chunking_strategy,
metadata=metadata,
embedding_model=embedding_model,
embedding_dimension=embedding_dimension,
provider_id=registered_vector_db.provider_id,
provider_vector_db_id=registered_vector_db.provider_resource_id,
)
async def openai_list_vector_stores(
self,
limit: int = 20,
order: str = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
logger.debug(f"VectorIORouter.openai_list_vector_stores: limit={limit}")
# Aggregate across all providers: retrieve each registered vector DB's store
vector_dbs = await self.routing_table.get_all_with_type("vector_db")
all_stores = []
for vector_db in vector_dbs:
vector_store = await self.routing_table.get_provider_impl(
vector_db.identifier
).openai_retrieve_vector_store(vector_db.identifier)
all_stores.append(vector_store)
# Sort by created_at
reverse_order = order == "desc"
all_stores.sort(key=lambda x: x.created_at, reverse=reverse_order)
# Apply cursor-based pagination
if after:
after_index = next((i for i, store in enumerate(all_stores) if store.id == after), -1)
if after_index >= 0:
all_stores = all_stores[after_index + 1 :]
if before:
before_index = next((i for i, store in enumerate(all_stores) if store.id == before), len(all_stores))
all_stores = all_stores[:before_index]
# Apply limit
limited_stores = all_stores[:limit]
# Determine pagination info
has_more = len(all_stores) > limit
first_id = limited_stores[0].id if limited_stores else None
last_id = limited_stores[-1].id if limited_stores else None
return VectorStoreListResponse(
data=limited_stores,
has_more=has_more,
first_id=first_id,
last_id=last_id,
)
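# Cursor semantics sketch (hypothetical IDs) for the pagination above:
#   sorted desc: ["vs_d", "vs_c", "vs_b", "vs_a"]
#   after="vs_c"  -> ["vs_b", "vs_a"]         (items strictly after the cursor)
#   before="vs_a" -> ["vs_d", "vs_c", "vs_b"] (items strictly before it)
#   limit=2 -> first two of the remainder; has_more flags anything left over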
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_retrieve_vector_store(vector_store_id)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_update_vector_store(
vector_store_id=vector_store_id,
name=name,
expires_after=expires_after,
metadata=metadata,
)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
result = await provider.openai_delete_vector_store(vector_store_id)
# drop from registry
await self.routing_table.unregister_vector_db(vector_store_id)
return result
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int = 10,
ranking_options: dict[str, Any] | None = None,
rewrite_query: bool = False,
) -> VectorStoreSearchResponse:
logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}")
# Route based on vector store ID
provider = self.routing_table.get_provider_impl(vector_store_id)
return await provider.openai_search_vector_store(
vector_store_id=vector_store_id,
query=query,
filters=filters,
max_num_results=max_num_results,
ranking_options=ranking_options,
rewrite_query=rewrite_query,
)


@@ -9,16 +9,26 @@ import base64
import io
import json
import logging
from typing import Any
import time
import uuid
from typing import Any, Literal
import faiss
import numpy as np
from numpy.typing import NDArray
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponse,
)
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
@@ -34,6 +44,11 @@ logger = logging.getLogger(__name__)
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
# Over-fetch factor: query_chunks retrieves max_num_results * CHUNK_MULTIPLIER
# candidates so that enough results survive score and metadata filtering.
CHUNK_MULTIPLIER = 5
class FaissIndex(EmbeddingIndex):
@@ -131,6 +146,7 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
self.inference_api = inference_api
self.cache: dict[str, VectorDBWithIndex] = {}
self.kvstore: KVStore | None = None
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
async def initialize(self) -> None:
self.kvstore = await kvstore_impl(self.config.kvstore)
@@ -148,6 +164,15 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
)
self.cache[vector_db.identifier] = index
# Load existing OpenAI vector stores
start_key = OPENAI_VECTOR_STORES_PREFIX
end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
for store_data in stored_openai_stores:
store_info = json.loads(store_data)
self.openai_vector_stores[store_info["id"]] = store_info
async def shutdown(self) -> None:
# Cleanup if needed
pass
@@ -208,3 +233,286 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params)
# OpenAI Vector Stores API endpoints implementation
async def openai_create_vector_store(
self,
name: str | None = None,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
"""Creates a vector store."""
assert self.kvstore is not None
# store and vector_db have the same id
store_id = name or str(uuid.uuid4())
created_at = int(time.time())
if provider_id is None:
raise ValueError("Provider ID is required")
if embedding_model is None:
raise ValueError("Embedding model is required")
# Use provided embedding dimension or default to 384
if embedding_dimension is None:
raise ValueError("Embedding dimension is required")
provider_vector_db_id = provider_vector_db_id or store_id
vector_db = VectorDB(
identifier=store_id,
embedding_dimension=embedding_dimension,
embedding_model=embedding_model,
provider_id=provider_id,
provider_resource_id=provider_vector_db_id,
)
# Register the vector DB
await self.register_vector_db(vector_db)
# Create OpenAI vector store metadata
store_info = {
"id": store_id,
"object": "vector_store",
"created_at": created_at,
"name": store_id,
"usage_bytes": 0,
"file_counts": {},
"status": "completed",
"expires_after": expires_after,
"expires_at": None,
"last_active_at": created_at,
"file_ids": file_ids or [],
"chunking_strategy": chunking_strategy,
}
# Add provider information to metadata if provided
metadata = metadata or {}
if provider_id:
metadata["provider_id"] = provider_id
if provider_vector_db_id:
metadata["provider_vector_db_id"] = provider_vector_db_id
store_info["metadata"] = metadata
# Store in kvstore
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
# Store in memory cache
self.openai_vector_stores[store_id] = store_info
return VectorStoreObject(
id=store_id,
created_at=created_at,
name=store_id,
usage_bytes=0,
file_counts={},
status="completed",
expires_after=expires_after,
expires_at=None,
last_active_at=created_at,
metadata=metadata,
)
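# Persistence sketch: each store is kept as one JSON blob under a versioned
# key, e.g. openai_vector_stores:v3::<store_id> -> {"id": ..., "created_at": ...},
# and reloaded into self.openai_vector_stores during initialize().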
async def openai_list_vector_stores(
self,
limit: int = 20,
order: str = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
"""Returns a list of vector stores."""
# Get all vector stores
all_stores = list(self.openai_vector_stores.values())
# Sort by created_at
reverse_order = order == "desc"
all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order)
# Apply cursor-based pagination
if after:
after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
if after_index >= 0:
all_stores = all_stores[after_index + 1 :]
if before:
before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
all_stores = all_stores[:before_index]
# Apply limit
limited_stores = all_stores[:limit]
# Convert to VectorStoreObject instances
data = [VectorStoreObject(**store) for store in limited_stores]
# Determine pagination info
has_more = len(all_stores) > limit
first_id = data[0].id if data else None
last_id = data[-1].id if data else None
return VectorStoreListResponse(
data=data,
has_more=has_more,
first_id=first_id,
last_id=last_id,
)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
"""Retrieves a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id]
return VectorStoreObject(**store_info)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
"""Modifies a vector store."""
assert self.kvstore is not None
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id].copy()
# Update fields if provided
if name is not None:
store_info["name"] = name
if expires_after is not None:
store_info["expires_after"] = expires_after
if metadata is not None:
store_info["metadata"] = metadata
# Update last_active_at
store_info["last_active_at"] = int(time.time())
# Save to kvstore
key = f"{OPENAI_VECTOR_STORES_PREFIX}{vector_store_id}"
await self.kvstore.set(key=key, value=json.dumps(store_info))
# Update in-memory cache
self.openai_vector_stores[vector_store_id] = store_info
return VectorStoreObject(**store_info)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
"""Delete a vector store."""
assert self.kvstore is not None
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
# Delete from kvstore
key = f"{OPENAI_VECTOR_STORES_PREFIX}{vector_store_id}"
await self.kvstore.delete(key)
# Delete from in-memory cache
del self.openai_vector_stores[vector_store_id]
# Also delete the underlying vector DB
try:
await self.unregister_vector_db(vector_store_id)
except Exception as e:
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
return VectorStoreDeleteResponse(
id=vector_store_id,
deleted=True,
)
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int = 10,
ranking_options: dict[str, Any] | None = None,
rewrite_query: bool = False,
search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
) -> VectorStoreSearchResponse:
"""Search for chunks in a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
if isinstance(query, list):
search_query = " ".join(query)
else:
search_query = query
try:
score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
params = {
"max_chunks": max_num_results * CHUNK_MULTIPLIER,
"score_threshold": score_threshold,
"mode": search_mode,
}
# TODO: Add support for ranking_options.ranker
response = await self.query_chunks(
vector_db_id=vector_store_id,
query=search_query,
params=params,
)
# Convert response to OpenAI format
data = []
for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)):
# Apply score based filtering
if score < score_threshold:
continue
# Apply filters if provided
if filters:
# Simple metadata filtering
if not self._matches_filters(chunk.metadata, filters):
continue
chunk_data = {
"id": f"chunk_{i}",
"object": "vector_store.search_result",
"score": score,
"content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content),
"metadata": chunk.metadata,
}
data.append(chunk_data)
if len(data) >= max_num_results:
break
return VectorStoreSearchResponse(
search_query=search_query,
data=data,
has_more=False, # For simplicity, we don't implement pagination here
next_page=None,
)
except Exception as e:
logger.error(f"Error searching vector store {vector_store_id}: {e}")
# Return empty results on error
return VectorStoreSearchResponse(
search_query=search_query,
data=[],
has_more=False,
next_page=None,
)
def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
"""Check if metadata matches the provided filters."""
for key, value in filters.items():
if key not in metadata:
return False
if metadata[key] != value:
return False
return True
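
The filtering above is a conjunctive exact match over metadata keys; a
self-contained sketch of the same semantics, with hypothetical values:

```
def matches_filters(metadata: dict, filters: dict) -> bool:
    # Every filter key must be present in metadata with an equal value.
    return all(k in metadata and metadata[k] == v for k, v in filters.items())

assert matches_filters({"topic": "geo", "lang": "en"}, {"topic": "geo"})
assert not matches_filters({"topic": "geo"}, {"topic": "history"})  # value mismatch
assert not matches_filters({"topic": "geo"}, {"source": "wiki"})    # missing key
```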


@@ -6,11 +6,13 @@
import asyncio
import hashlib
import json
import logging
import sqlite3
import struct
import time
import uuid
from typing import Any
from typing import Any, Literal
import numpy as np
import sqlite_vec
@@ -18,7 +20,15 @@ from numpy.typing import NDArray
from llama_stack.apis.inference.inference import Inference
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponse,
)
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex
@@ -29,6 +39,11 @@ VECTOR_SEARCH = "vector"
KEYWORD_SEARCH = "keyword"
SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH}
# Constants for OpenAI vector stores (similar to faiss)
VERSION = "v3"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
CHUNK_MULTIPLIER = 5
def serialize_vector(vector: list[float]) -> bytes:
"""Serialize a list of floats into a compact binary representation."""
@@ -299,6 +314,7 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
self.config = config
self.inference_api = inference_api
self.cache: dict[str, VectorDBWithIndex] = {}
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
async def initialize(self) -> None:
def _setup_connection():
@@ -313,17 +329,29 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
metadata TEXT
);
""")
# Create a table to persist OpenAI vector stores.
cur.execute("""
CREATE TABLE IF NOT EXISTS openai_vector_stores (
id TEXT PRIMARY KEY,
metadata TEXT
);
""")
connection.commit()
# Load any existing vector DB registrations.
cur.execute("SELECT metadata FROM vector_dbs")
rows = cur.fetchall()
return rows
vector_db_rows = cur.fetchall()
# Load any existing OpenAI vector stores.
cur.execute("SELECT metadata FROM openai_vector_stores")
openai_store_rows = cur.fetchall()
return vector_db_rows, openai_store_rows
finally:
cur.close()
connection.close()
rows = await asyncio.to_thread(_setup_connection)
for row in rows:
vector_db_rows, openai_store_rows = await asyncio.to_thread(_setup_connection)
# Load existing vector DBs
for row in vector_db_rows:
vector_db_data = row[0]
vector_db = VectorDB.model_validate_json(vector_db_data)
index = await SQLiteVecIndex.create(
@@ -331,6 +359,12 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
# Load existing OpenAI vector stores
for row in openai_store_rows:
store_data = row[0]
store_info = json.loads(store_data)
self.openai_vector_stores[store_info["id"]] = store_info
async def shutdown(self) -> None:
# nothing to do since we don't maintain a persistent connection
pass
@@ -389,6 +423,318 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
raise ValueError(f"Vector DB {vector_db_id} not found")
return await self.cache[vector_db_id].query_chunks(query, params)
async def openai_create_vector_store(
self,
name: str | None = None,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
"""Creates a vector store."""
# store and vector_db have the same id
store_id = name or str(uuid.uuid4())
created_at = int(time.time())
if provider_id is None:
raise ValueError("Provider ID is required")
if embedding_model is None:
raise ValueError("Embedding model is required")
# Use provided embedding dimension or default to 384
if embedding_dimension is None:
raise ValueError("Embedding dimension is required")
provider_vector_db_id = provider_vector_db_id or store_id
vector_db = VectorDB(
identifier=store_id,
embedding_dimension=embedding_dimension,
embedding_model=embedding_model,
provider_id=provider_id,
provider_resource_id=provider_vector_db_id,
)
# Register the vector DB
await self.register_vector_db(vector_db)
# Create OpenAI vector store metadata
store_info = {
"id": store_id,
"object": "vector_store",
"created_at": created_at,
"name": store_id,
"usage_bytes": 0,
"file_counts": {},
"status": "completed",
"expires_after": expires_after,
"expires_at": None,
"last_active_at": created_at,
"file_ids": file_ids or [],
"chunking_strategy": chunking_strategy,
}
# Add provider information to metadata if provided
metadata = metadata or {}
if provider_id:
metadata["provider_id"] = provider_id
if provider_vector_db_id:
metadata["provider_vector_db_id"] = provider_vector_db_id
store_info["metadata"] = metadata
# Store in SQLite database
def _store_openai_vector_store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)",
(store_id, json.dumps(store_info)),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_store_openai_vector_store)
# Store in memory cache
self.openai_vector_stores[store_id] = store_info
return VectorStoreObject(
id=store_id,
created_at=created_at,
name=store_id,
usage_bytes=0,
file_counts={},
status="completed",
expires_after=expires_after,
expires_at=None,
last_active_at=created_at,
metadata=metadata,
)
async def openai_list_vector_stores(
self,
limit: int = 20,
order: str = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
"""Returns a list of vector stores."""
# Get all vector stores
all_stores = list(self.openai_vector_stores.values())
# Sort by created_at
reverse_order = order == "desc"
all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order)
# Apply cursor-based pagination
if after:
after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1)
if after_index >= 0:
all_stores = all_stores[after_index + 1 :]
if before:
before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores))
all_stores = all_stores[:before_index]
# Apply limit
limited_stores = all_stores[:limit]
# Convert to VectorStoreObject instances
data = [VectorStoreObject(**store) for store in limited_stores]
# Determine pagination info
has_more = len(all_stores) > limit
first_id = data[0].id if data else None
last_id = data[-1].id if data else None
return VectorStoreListResponse(
data=data,
has_more=has_more,
first_id=first_id,
last_id=last_id,
)
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
"""Retrieves a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id]
return VectorStoreObject(**store_info)
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
"""Modifies a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
store_info = self.openai_vector_stores[vector_store_id].copy()
# Update fields if provided
if name is not None:
store_info["name"] = name
if expires_after is not None:
store_info["expires_after"] = expires_after
if metadata is not None:
store_info["metadata"] = metadata
# Update last_active_at
store_info["last_active_at"] = int(time.time())
# Save to SQLite database
def _update_openai_vector_store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"UPDATE openai_vector_stores SET metadata = ? WHERE id = ?",
(json.dumps(store_info), vector_store_id),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_update_openai_vector_store)
# Update in-memory cache
self.openai_vector_stores[vector_store_id] = store_info
return VectorStoreObject(**store_info)
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
"""Delete a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
# Delete from SQLite database
def _delete_openai_vector_store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (vector_store_id,))
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete_openai_vector_store)
# Delete from in-memory cache
del self.openai_vector_stores[vector_store_id]
# Also delete the underlying vector DB
try:
await self.unregister_vector_db(vector_store_id)
except Exception as e:
logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}")
return VectorStoreDeleteResponse(
id=vector_store_id,
deleted=True,
)
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int = 10,
ranking_options: dict[str, Any] | None = None,
rewrite_query: bool = False,
search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
) -> VectorStoreSearchResponse:
"""Search for chunks in a vector store."""
if vector_store_id not in self.openai_vector_stores:
raise ValueError(f"Vector store {vector_store_id} not found")
if isinstance(query, list):
search_query = " ".join(query)
else:
search_query = query
try:
score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0
params = {
"max_chunks": max_num_results * CHUNK_MULTIPLIER,
"score_threshold": score_threshold,
"mode": search_mode,
}
# TODO: Add support for ranking_options.ranker
response = await self.query_chunks(
vector_db_id=vector_store_id,
query=search_query,
params=params,
)
# Convert response to OpenAI format
data = []
for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)):
# Apply score based filtering
if score < score_threshold:
continue
# Apply filters if provided
if filters:
# Simple metadata filtering
if not self._matches_filters(chunk.metadata, filters):
continue
chunk_data = {
"id": f"chunk_{i}",
"object": "vector_store.search_result",
"score": score,
"content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content),
"metadata": chunk.metadata,
}
data.append(chunk_data)
if len(data) >= max_num_results:
break
return VectorStoreSearchResponse(
search_query=search_query,
data=data,
has_more=False, # For simplicity, we don't implement pagination here
next_page=None,
)
except Exception as e:
logger.error(f"Error searching vector store {vector_store_id}: {e}")
# Return empty results on error
return VectorStoreSearchResponse(
search_query=search_query,
data=[],
has_more=False,
next_page=None,
)
def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool:
"""Check if metadata matches the provided filters."""
for key, value in filters.items():
if key not in metadata:
return False
if metadata[key] != value:
return False
return True
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
"""Generate a unique chunk ID using a hash of document ID and chunk text."""


@@ -6,7 +6,7 @@
import asyncio
import json
import logging
from typing import Any
from typing import Any, Literal
from urllib.parse import urlparse
import chromadb
@@ -14,7 +14,15 @@ from numpy.typing import NDArray
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponse,
)
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
from llama_stack.providers.utils.memory.vector_store import (
@@ -178,3 +186,59 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api)
self.cache[vector_db_id] = index
return index
async def openai_create_vector_store(
self,
name: str | None = None,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_list_vector_stores(
self,
limit: int = 20,
order: str = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int = 10,
ranking_options: dict[str, Any] | None = None,
rewrite_query: bool = False,
search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
) -> VectorStoreSearchResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")


@@ -9,14 +9,22 @@ import hashlib
import logging
import os
import uuid
from typing import Any
from typing import Any, Literal
from numpy.typing import NDArray
from pymilvus import MilvusClient
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponse,
)
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
from llama_stack.providers.utils.memory.vector_store import (
@@ -177,6 +185,62 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
return await index.query_chunks(query, params)
async def openai_create_vector_store(
self,
name: str | None = None,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_list_vector_stores(
self,
limit: int = 20,
order: str = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int = 10,
ranking_options: dict[str, Any] | None = None,
rewrite_query: bool = False,
search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
) -> VectorStoreSearchResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
"""Generate a unique chunk ID using a hash of document ID and chunk text."""


@@ -6,7 +6,7 @@
import logging
import uuid
from typing import Any
from typing import Any, Literal
from numpy.typing import NDArray
from qdrant_client import AsyncQdrantClient, models
@@ -14,7 +14,15 @@ from qdrant_client.models import PointStruct
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
VectorIO,
VectorStoreDeleteResponse,
VectorStoreListResponse,
VectorStoreObject,
VectorStoreSearchResponse,
)
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.memory.vector_store import (
@@ -178,3 +186,59 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params)
async def openai_create_vector_store(
self,
name: str | None = None,
file_ids: list[str] | None = None,
expires_after: dict[str, Any] | None = None,
chunking_strategy: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
embedding_model: str | None = None,
embedding_dimension: int | None = 384,
provider_id: str | None = None,
provider_vector_db_id: str | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_list_vector_stores(
self,
limit: int = 20,
order: str = "desc",
after: str | None = None,
before: str | None = None,
) -> VectorStoreListResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_retrieve_vector_store(
self,
vector_store_id: str,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_update_vector_store(
self,
vector_store_id: str,
name: str | None = None,
expires_after: dict[str, Any] | None = None,
metadata: dict[str, Any] | None = None,
) -> VectorStoreObject:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_delete_vector_store(
self,
vector_store_id: str,
) -> VectorStoreDeleteResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
async def openai_search_vector_store(
self,
vector_store_id: str,
query: str | list[str],
filters: dict[str, Any] | None = None,
max_num_results: int = 10,
ranking_options: dict[str, Any] | None = None,
rewrite_query: bool = False,
search_mode: Literal["keyword", "vector", "hybrid"] = "vector",
) -> VectorStoreSearchResponse:
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")