From 4e37b49cdc951ac313ece8fe87e6829a9257b321 Mon Sep 17 00:00:00 2001 From: Rohan Awhad <30470101+RohanAwhad@users.noreply.github.com> Date: Wed, 11 Jun 2025 12:14:41 -0400 Subject: [PATCH 1/9] fix: #1867 InferenceRouter has no attribute formatter (#2422) # What does this PR do? Closes #1867 [Steps to reproduce the bug](https://github.com/meta-llama/llama-stack/issues/1867#issuecomment-2956819381) The change was designed to minimize code changes. Open to option of skipping `metrics` field entirely when `telemetry` is disabled. ## Test Plan 1. Build llama-stack remote-vllm container ```bash llama stack build --template remote-vllm --image-type container ``` 2. Create a small run.yaml ```yaml version: '2' image_name: remote-vllm apis: - inference providers: inference: - provider_id: vllm-inference provider_type: remote::vllm config: url: ${env.VLLM_URL:http://localhost:8000/v1} max_tokens: ${env.VLLM_MAX_TOKENS:4096} api_token: ${env.VLLM_API_TOKEN:fake} tls_verify: ${env.VLLM_TLS_VERIFY:true} metadata_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/registry.db inference_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/remote-vllm}/inference_store.db models: - metadata: {} model_id: ${env.INFERENCE_MODEL} provider_id: vllm-inference model_type: llm shields: [] vector_dbs: [] datasets: [] scoring_fns: [] benchmarks: [] server: port: 8321 ``` 3. Run the llama-stack server ```bash export VLLM_URL="http://localhost:8000/v1" export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" llama stack run run.yaml ``` 4. Then perform a curl ```bash curl -X 'POST' \ 'http://localhost:8321/v1/inference/completion' \ -H 'accept: application/json' \ -H 'Content-Type: application/json' \ -d '{ "model_id": "meta-llama/Llama-3.2-3B-Instruct", "content": "string", "sampling_params": { "strategy": { "type": "greedy" }, "max_tokens": 10, "repetition_penalty": 1, "stop": [ "string" ] }, "stream": false, "logprobs": { "top_k": 0 } }' ``` 5. You should receive a 200 response with metric values set to 0, similar to one below: ``` { "metrics": [ { "metric": "prompt_tokens", "value": 0, "unit": null }, { "metric": "completion_tokens", "value": 0, "unit": null }, { "metric": "total_tokens", "value": 0, "unit": null } ], [...] } ``` Co-authored-by: Rohan Awhad --- llama_stack/distribution/routers/inference.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llama_stack/distribution/routers/inference.py b/llama_stack/distribution/routers/inference.py index 2e111c20a..62d04cdc4 100644 --- a/llama_stack/distribution/routers/inference.py +++ b/llama_stack/distribution/routers/inference.py @@ -163,6 +163,9 @@ class InferenceRouter(Inference): messages: list[Message] | InterleavedContent, tool_prompt_format: ToolPromptFormat | None = None, ) -> int | None: + if not hasattr(self, "formatter") or self.formatter is None: + return None + if isinstance(messages, list): encoded = self.formatter.encode_dialog_prompt(messages, tool_prompt_format) else: From d55100d9b7971aae6414ce2134b8198edb20dc29 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Wed, 11 Jun 2025 15:40:57 -0700 Subject: [PATCH 2/9] feat: OpenAIVectorIOMixin for vector_stores common logic (#2427) Extracts common OpenAI vector-store code into its own mixin so that all providers can share the same core logic. This also makes it easy for Llama Stack to support both vector-stores and Llama Stack APIs in the interim so that both share the same underlying vector-dbs. 
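At a high level the split looks like the condensed sketch below (the class and method names match the new `OpenAIVectorStoreMixin` added in this patch; the bodies are heavily abridged rather than the real implementation):

```python
from abc import ABC, abstractmethod
from typing import Any


class OpenAIVectorStoreMixin(ABC):
    # In-memory cache of store metadata, maintained by the implementing provider.
    openai_vector_stores: dict[str, dict[str, Any]]

    # Storage-specific hooks each provider (faiss, sqlite-vec, ...) implements.
    @abstractmethod
    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: ...

    @abstractmethod
    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]: ...

    # Common OpenAI-facing plumbing lives here once, shared by every provider.
    async def openai_retrieve_vector_store(self, vector_store_id: str) -> dict[str, Any]:
        # Abridged: the real method wraps this metadata in a VectorStoreObject.
        if vector_store_id not in self.openai_vector_stores:
            raise ValueError(f"Vector store {vector_store_id} not found")
        return self.openai_vector_stores[vector_store_id]
```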
Each provider contains storage specific logic to `create / edit / delete / list` vector dbs while the plumbing logic is standardized in the common code. Ensured that this works well with both faiss and sqllite-vec. ### Test Plan ``` llama stack run starter pytest -sv --stack-config http://localhost:8321 tests/integration/vector_io/test_openai_vector_stores.py --embedding-model all-MiniLM-L6-v2 ``` --- docs/_static/llama-stack-spec.html | 3 + docs/_static/llama-stack-spec.yaml | 2 + llama_stack/apis/vector_io/vector_io.py | 2 +- llama_stack/distribution/resolver.py | 10 +- llama_stack/distribution/routers/vector_io.py | 2 +- .../providers/inline/vector_io/faiss/faiss.py | 316 ++----------- .../inline/vector_io/sqlite_vec/sqlite_vec.py | 422 ++++-------------- .../remote/vector_io/chroma/chroma.py | 2 +- .../remote/vector_io/milvus/milvus.py | 2 +- .../remote/vector_io/qdrant/qdrant.py | 2 +- .../utils/memory/openai_vector_store_mixin.py | 354 +++++++++++++++ 11 files changed, 484 insertions(+), 633 deletions(-) create mode 100644 llama_stack/providers/utils/memory/openai_vector_store_mixin.py diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index f5958754e..18f4f49b9 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -12583,6 +12583,9 @@ } }, "additionalProperties": false, + "required": [ + "name" + ], "title": "OpenaiCreateVectorStoreRequest" }, "VectorStoreObject": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index c302e58fb..4dbcd3ac9 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -8791,6 +8791,8 @@ components: description: >- The provider-specific vector database ID. additionalProperties: false + required: + - name title: OpenaiCreateVectorStoreRequest VectorStoreObject: type: object diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py index 229d62386..5f54539a4 100644 --- a/llama_stack/apis/vector_io/vector_io.py +++ b/llama_stack/apis/vector_io/vector_io.py @@ -165,7 +165,7 @@ class VectorIO(Protocol): @webmethod(route="/openai/v1/vector_stores", method="POST") async def openai_create_vector_store( self, - name: str | None = None, + name: str, file_ids: list[str] | None = None, expires_after: dict[str, Any] | None = None, chunking_strategy: dict[str, Any] | None = None, diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 6e7bb5edd..e71ff8092 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -394,9 +394,13 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None: logger.error(f"Method {name} incompatible proto: {proto_params} vs. 
obj: {obj_params}") missing_methods.append((name, "signature_mismatch")) else: - # Check if the method is actually implemented in the class - method_owner = next((cls for cls in mro if name in cls.__dict__), None) - if method_owner is None or method_owner.__name__ == protocol.__name__: + # Check if the method has a concrete implementation (not just a protocol stub) + # Find all classes in MRO that define this method + method_owners = [cls for cls in mro if name in cls.__dict__] + + # Allow methods from mixins/parents, only reject if ONLY the protocol defines it + if len(method_owners) == 1 and method_owners[0].__name__ == protocol.__name__: + # Only reject if the method is ONLY defined in the protocol itself (abstract stub) missing_methods.append((name, "not_actually_implemented")) if missing_methods: diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py index 92ddcbcfa..30b19c436 100644 --- a/llama_stack/distribution/routers/vector_io.py +++ b/llama_stack/distribution/routers/vector_io.py @@ -108,7 +108,7 @@ class VectorIORouter(VectorIO): # OpenAI Vector Stores API endpoints async def openai_create_vector_store( self, - name: str | None = None, + name: str, file_ids: list[str] | None = None, expires_after: dict[str, Any] | None = None, chunking_strategy: dict[str, Any] | None = None, diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index d0b195cbe..5e9155011 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -9,9 +9,7 @@ import base64 import io import json import logging -import time -import uuid -from typing import Any, Literal +from typing import Any import faiss import numpy as np @@ -24,14 +22,11 @@ from llama_stack.apis.vector_io import ( Chunk, QueryChunksResponse, VectorIO, - VectorStoreDeleteResponse, - VectorStoreListResponse, - VectorStoreObject, - VectorStoreSearchResponse, ) from llama_stack.providers.datatypes import VectorDBsProtocolPrivate from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.api import KVStore +from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, @@ -47,10 +42,6 @@ FAISS_INDEX_PREFIX = f"faiss_index:{VERSION}::" OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::" -# In faiss, since we do -CHUNK_MULTIPLIER = 5 - - class FaissIndex(EmbeddingIndex): def __init__(self, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None): self.index = faiss.IndexFlatL2(dimension) @@ -140,7 +131,7 @@ class FaissIndex(EmbeddingIndex): raise NotImplementedError("Keyword search is not supported in FAISS") -class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): +class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): def __init__(self, config: FaissVectorIOConfig, inference_api: Inference) -> None: self.config = config self.inference_api = inference_api @@ -164,14 +155,8 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): ) self.cache[vector_db.identifier] = index - # Load existing OpenAI vector stores - start_key = OPENAI_VECTOR_STORES_PREFIX - end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff" - stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key) - - for store_data in 
stored_openai_stores: - store_info = json.loads(store_data) - self.openai_vector_stores[store_info["id"]] = store_info + # Load existing OpenAI vector stores using the mixin method + self.openai_vector_stores = await self._load_openai_vector_stores() async def shutdown(self) -> None: # Cleanup if needed @@ -234,285 +219,34 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): return await index.query_chunks(query, params) - # OpenAI Vector Stores API endpoints implementation - async def openai_create_vector_store( - self, - name: str | None = None, - file_ids: list[str] | None = None, - expires_after: dict[str, Any] | None = None, - chunking_strategy: dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, - embedding_model: str | None = None, - embedding_dimension: int | None = 384, - provider_id: str | None = None, - provider_vector_db_id: str | None = None, - ) -> VectorStoreObject: - """Creates a vector store.""" + # OpenAI Vector Store Mixin abstract method implementations + async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Save vector store metadata to kvstore.""" assert self.kvstore is not None - # store and vector_db have the same id - store_id = name or str(uuid.uuid4()) - created_at = int(time.time()) - - if provider_id is None: - raise ValueError("Provider ID is required") - - if embedding_model is None: - raise ValueError("Embedding model is required") - - # Use provided embedding dimension or default to 384 - if embedding_dimension is None: - raise ValueError("Embedding dimension is required") - - provider_vector_db_id = provider_vector_db_id or store_id - vector_db = VectorDB( - identifier=store_id, - embedding_dimension=embedding_dimension, - embedding_model=embedding_model, - provider_id=provider_id, - provider_resource_id=provider_vector_db_id, - ) - - # Register the vector DB - await self.register_vector_db(vector_db) - - # Create OpenAI vector store metadata - store_info = { - "id": store_id, - "object": "vector_store", - "created_at": created_at, - "name": store_id, - "usage_bytes": 0, - "file_counts": {}, - "status": "completed", - "expires_after": expires_after, - "expires_at": None, - "last_active_at": created_at, - "file_ids": file_ids or [], - "chunking_strategy": chunking_strategy, - } - - # Add provider information to metadata if provided - metadata = metadata or {} - if provider_id: - metadata["provider_id"] = provider_id - if provider_vector_db_id: - metadata["provider_vector_db_id"] = provider_vector_db_id - store_info["metadata"] = metadata - - # Store in kvstore key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" await self.kvstore.set(key=key, value=json.dumps(store_info)) - # Store in memory cache - self.openai_vector_stores[store_id] = store_info - - return VectorStoreObject( - id=store_id, - created_at=created_at, - name=store_id, - usage_bytes=0, - file_counts={}, - status="completed", - expires_after=expires_after, - expires_at=None, - last_active_at=created_at, - metadata=metadata, - ) - - async def openai_list_vector_stores( - self, - limit: int = 20, - order: str = "desc", - after: str | None = None, - before: str | None = None, - ) -> VectorStoreListResponse: - """Returns a list of vector stores.""" - # Get all vector stores - all_stores = list(self.openai_vector_stores.values()) - - # Sort by created_at - reverse_order = order == "desc" - all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order) - - # Apply cursor-based pagination - if after: - after_index = 
next((i for i, store in enumerate(all_stores) if store["id"] == after), -1) - if after_index >= 0: - all_stores = all_stores[after_index + 1 :] - - if before: - before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores)) - all_stores = all_stores[:before_index] - - # Apply limit - limited_stores = all_stores[:limit] - # Convert to VectorStoreObject instances - data = [VectorStoreObject(**store) for store in limited_stores] - - # Determine pagination info - has_more = len(all_stores) > limit - first_id = data[0].id if data else None - last_id = data[-1].id if data else None - - return VectorStoreListResponse( - data=data, - has_more=has_more, - first_id=first_id, - last_id=last_id, - ) - - async def openai_retrieve_vector_store( - self, - vector_store_id: str, - ) -> VectorStoreObject: - """Retrieves a vector store.""" - if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") - - store_info = self.openai_vector_stores[vector_store_id] - return VectorStoreObject(**store_info) - - async def openai_update_vector_store( - self, - vector_store_id: str, - name: str | None = None, - expires_after: dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, - ) -> VectorStoreObject: - """Modifies a vector store.""" + async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]: + """Load all vector store metadata from kvstore.""" assert self.kvstore is not None - if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") + start_key = OPENAI_VECTOR_STORES_PREFIX + end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff" + stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key) - store_info = self.openai_vector_stores[vector_store_id].copy() + stores = {} + for store_data in stored_openai_stores: + store_info = json.loads(store_data) + stores[store_info["id"]] = store_info + return stores - # Update fields if provided - if name is not None: - store_info["name"] = name - if expires_after is not None: - store_info["expires_after"] = expires_after - if metadata is not None: - store_info["metadata"] = metadata - - # Update last_active_at - store_info["last_active_at"] = int(time.time()) - - # Save to kvstore - key = f"{OPENAI_VECTOR_STORES_PREFIX}{vector_store_id}" + async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Update vector store metadata in kvstore.""" + assert self.kvstore is not None + key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" await self.kvstore.set(key=key, value=json.dumps(store_info)) - # Update in-memory cache - self.openai_vector_stores[vector_store_id] = store_info - - return VectorStoreObject(**store_info) - - async def openai_delete_vector_store( - self, - vector_store_id: str, - ) -> VectorStoreDeleteResponse: - """Delete a vector store.""" + async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None: + """Delete vector store metadata from kvstore.""" assert self.kvstore is not None - if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") - - # Delete from kvstore - key = f"{OPENAI_VECTOR_STORES_PREFIX}{vector_store_id}" + key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}" await self.kvstore.delete(key) - - # Delete from in-memory cache - del self.openai_vector_stores[vector_store_id] - - # Also delete the underlying vector DB - try: - await 
self.unregister_vector_db(vector_store_id) - except Exception as e: - logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}") - - return VectorStoreDeleteResponse( - id=vector_store_id, - deleted=True, - ) - - async def openai_search_vector_store( - self, - vector_store_id: str, - query: str | list[str], - filters: dict[str, Any] | None = None, - max_num_results: int = 10, - ranking_options: dict[str, Any] | None = None, - rewrite_query: bool = False, - search_mode: Literal["keyword", "vector", "hybrid"] = "vector", - ) -> VectorStoreSearchResponse: - """Search for chunks in a vector store.""" - if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") - - if isinstance(query, list): - search_query = " ".join(query) - else: - search_query = query - - try: - score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0 - params = { - "max_chunks": max_num_results * CHUNK_MULTIPLIER, - "score_threshold": score_threshold, - "mode": search_mode, - } - # TODO: Add support for ranking_options.ranker - - response = await self.query_chunks( - vector_db_id=vector_store_id, - query=search_query, - params=params, - ) - - # Convert response to OpenAI format - data = [] - for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)): - # Apply score based filtering - if score < score_threshold: - continue - - # Apply filters if provided - if filters: - # Simple metadata filtering - if not self._matches_filters(chunk.metadata, filters): - continue - - chunk_data = { - "id": f"chunk_{i}", - "object": "vector_store.search_result", - "score": score, - "content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content), - "metadata": chunk.metadata, - } - data.append(chunk_data) - if len(data) >= max_num_results: - break - - return VectorStoreSearchResponse( - search_query=search_query, - data=data, - has_more=False, # For simplicity, we don't implement pagination here - next_page=None, - ) - - except Exception as e: - logger.error(f"Error searching vector store {vector_store_id}: {e}") - # Return empty results on error - return VectorStoreSearchResponse( - search_query=search_query, - data=[], - has_more=False, - next_page=None, - ) - - def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool: - """Check if metadata matches the provided filters.""" - for key, value in filters.items(): - if key not in metadata: - return False - if metadata[key] != value: - return False - return True diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 0a0f0e653..02f04e766 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -10,9 +10,8 @@ import json import logging import sqlite3 import struct -import time import uuid -from typing import Any, Literal +from typing import Any import numpy as np import sqlite_vec @@ -24,12 +23,9 @@ from llama_stack.apis.vector_io import ( Chunk, QueryChunksResponse, VectorIO, - VectorStoreDeleteResponse, - VectorStoreListResponse, - VectorStoreObject, - VectorStoreSearchResponse, ) from llama_stack.providers.datatypes import VectorDBsProtocolPrivate +from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin from llama_stack.providers.utils.memory.vector_store import 
EmbeddingIndex, VectorDBWithIndex logger = logging.getLogger(__name__) @@ -39,11 +35,6 @@ VECTOR_SEARCH = "vector" KEYWORD_SEARCH = "keyword" SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH} -# Constants for OpenAI vector stores (similar to faiss) -VERSION = "v3" -OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::" -CHUNK_MULTIPLIER = 5 - def serialize_vector(vector: list[float]) -> bytes: """Serialize a list of floats into a compact binary representation.""" @@ -303,7 +294,7 @@ class SQLiteVecIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) -class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): +class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate): """ A VectorIO implementation using SQLite + sqlite_vec. This class handles vector database registration (with metadata stored in a table named `vector_dbs`) @@ -340,15 +331,12 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): # Load any existing vector DB registrations. cur.execute("SELECT metadata FROM vector_dbs") vector_db_rows = cur.fetchall() - # Load any existing OpenAI vector stores. - cur.execute("SELECT metadata FROM openai_vector_stores") - openai_store_rows = cur.fetchall() - return vector_db_rows, openai_store_rows + return vector_db_rows finally: cur.close() connection.close() - vector_db_rows, openai_store_rows = await asyncio.to_thread(_setup_connection) + vector_db_rows = await asyncio.to_thread(_setup_connection) # Load existing vector DBs for row in vector_db_rows: @@ -359,11 +347,8 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): ) self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) - # Load existing OpenAI vector stores - for row in openai_store_rows: - store_data = row[0] - store_info = json.loads(store_data) - self.openai_vector_stores[store_info["id"]] = store_info + # Load existing OpenAI vector stores using the mixin method + self.openai_vector_stores = await self._load_openai_vector_stores() async def shutdown(self) -> None: # nothing to do since we don't maintain a persistent connection @@ -409,6 +394,87 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): await asyncio.to_thread(_delete_vector_db_from_registry) + # OpenAI Vector Store Mixin abstract method implementations + async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Save vector store metadata to SQLite database.""" + + def _store(): + connection = _create_sqlite_connection(self.config.db_path) + cur = connection.cursor() + try: + cur.execute( + "INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)", + (store_id, json.dumps(store_info)), + ) + connection.commit() + except Exception as e: + logger.error(f"Error saving openai vector store {store_id}: {e}") + raise + finally: + cur.close() + connection.close() + + try: + await asyncio.to_thread(_store) + except Exception as e: + logger.error(f"Error saving openai vector store {store_id}: {e}") + raise + + async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]: + """Load all vector store metadata from SQLite database.""" + + def _load(): + connection = _create_sqlite_connection(self.config.db_path) + cur = connection.cursor() + try: + cur.execute("SELECT metadata FROM openai_vector_stores") + rows = cur.fetchall() + return rows + finally: + cur.close() + connection.close() + + rows = await asyncio.to_thread(_load) + stores = {} + for row 
in rows: + store_data = row[0] + store_info = json.loads(store_data) + stores[store_info["id"]] = store_info + return stores + + async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Update vector store metadata in SQLite database.""" + + def _update(): + connection = _create_sqlite_connection(self.config.db_path) + cur = connection.cursor() + try: + cur.execute( + "UPDATE openai_vector_stores SET metadata = ? WHERE id = ?", + (json.dumps(store_info), store_id), + ) + connection.commit() + finally: + cur.close() + connection.close() + + await asyncio.to_thread(_update) + + async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None: + """Delete vector store metadata from SQLite database.""" + + def _delete(): + connection = _create_sqlite_connection(self.config.db_path) + cur = connection.cursor() + try: + cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (store_id,)) + connection.commit() + finally: + cur.close() + connection.close() + + await asyncio.to_thread(_delete) + async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None: if vector_db_id not in self.cache: raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}") @@ -423,318 +489,6 @@ class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): raise ValueError(f"Vector DB {vector_db_id} not found") return await self.cache[vector_db_id].query_chunks(query, params) - async def openai_create_vector_store( - self, - name: str | None = None, - file_ids: list[str] | None = None, - expires_after: dict[str, Any] | None = None, - chunking_strategy: dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, - embedding_model: str | None = None, - embedding_dimension: int | None = 384, - provider_id: str | None = None, - provider_vector_db_id: str | None = None, - ) -> VectorStoreObject: - """Creates a vector store.""" - # store and vector_db have the same id - store_id = name or str(uuid.uuid4()) - created_at = int(time.time()) - - if provider_id is None: - raise ValueError("Provider ID is required") - - if embedding_model is None: - raise ValueError("Embedding model is required") - - # Use provided embedding dimension or default to 384 - if embedding_dimension is None: - raise ValueError("Embedding dimension is required") - - provider_vector_db_id = provider_vector_db_id or store_id - vector_db = VectorDB( - identifier=store_id, - embedding_dimension=embedding_dimension, - embedding_model=embedding_model, - provider_id=provider_id, - provider_resource_id=provider_vector_db_id, - ) - - # Register the vector DB - await self.register_vector_db(vector_db) - - # Create OpenAI vector store metadata - store_info = { - "id": store_id, - "object": "vector_store", - "created_at": created_at, - "name": store_id, - "usage_bytes": 0, - "file_counts": {}, - "status": "completed", - "expires_after": expires_after, - "expires_at": None, - "last_active_at": created_at, - "file_ids": file_ids or [], - "chunking_strategy": chunking_strategy, - } - - # Add provider information to metadata if provided - metadata = metadata or {} - if provider_id: - metadata["provider_id"] = provider_id - if provider_vector_db_id: - metadata["provider_vector_db_id"] = provider_vector_db_id - store_info["metadata"] = metadata - - # Store in SQLite database - def _store_openai_vector_store(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: - 
cur.execute( - "INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)", - (store_id, json.dumps(store_info)), - ) - connection.commit() - finally: - cur.close() - connection.close() - - await asyncio.to_thread(_store_openai_vector_store) - - # Store in memory cache - self.openai_vector_stores[store_id] = store_info - - return VectorStoreObject( - id=store_id, - created_at=created_at, - name=store_id, - usage_bytes=0, - file_counts={}, - status="completed", - expires_after=expires_after, - expires_at=None, - last_active_at=created_at, - metadata=metadata, - ) - - async def openai_list_vector_stores( - self, - limit: int = 20, - order: str = "desc", - after: str | None = None, - before: str | None = None, - ) -> VectorStoreListResponse: - """Returns a list of vector stores.""" - # Get all vector stores - all_stores = list(self.openai_vector_stores.values()) - - # Sort by created_at - reverse_order = order == "desc" - all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order) - - # Apply cursor-based pagination - if after: - after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1) - if after_index >= 0: - all_stores = all_stores[after_index + 1 :] - - if before: - before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores)) - all_stores = all_stores[:before_index] - - # Apply limit - limited_stores = all_stores[:limit] - # Convert to VectorStoreObject instances - data = [VectorStoreObject(**store) for store in limited_stores] - - # Determine pagination info - has_more = len(all_stores) > limit - first_id = data[0].id if data else None - last_id = data[-1].id if data else None - - return VectorStoreListResponse( - data=data, - has_more=has_more, - first_id=first_id, - last_id=last_id, - ) - - async def openai_retrieve_vector_store( - self, - vector_store_id: str, - ) -> VectorStoreObject: - """Retrieves a vector store.""" - if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") - - store_info = self.openai_vector_stores[vector_store_id] - return VectorStoreObject(**store_info) - - async def openai_update_vector_store( - self, - vector_store_id: str, - name: str | None = None, - expires_after: dict[str, Any] | None = None, - metadata: dict[str, Any] | None = None, - ) -> VectorStoreObject: - """Modifies a vector store.""" - if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") - - store_info = self.openai_vector_stores[vector_store_id].copy() - - # Update fields if provided - if name is not None: - store_info["name"] = name - if expires_after is not None: - store_info["expires_after"] = expires_after - if metadata is not None: - store_info["metadata"] = metadata - - # Update last_active_at - store_info["last_active_at"] = int(time.time()) - - # Save to SQLite database - def _update_openai_vector_store(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: - cur.execute( - "UPDATE openai_vector_stores SET metadata = ? 
WHERE id = ?", - (json.dumps(store_info), vector_store_id), - ) - connection.commit() - finally: - cur.close() - connection.close() - - await asyncio.to_thread(_update_openai_vector_store) - - # Update in-memory cache - self.openai_vector_stores[vector_store_id] = store_info - - return VectorStoreObject(**store_info) - - async def openai_delete_vector_store( - self, - vector_store_id: str, - ) -> VectorStoreDeleteResponse: - """Delete a vector store.""" - if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") - - # Delete from SQLite database - def _delete_openai_vector_store(): - connection = _create_sqlite_connection(self.config.db_path) - cur = connection.cursor() - try: - cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (vector_store_id,)) - connection.commit() - finally: - cur.close() - connection.close() - - await asyncio.to_thread(_delete_openai_vector_store) - - # Delete from in-memory cache - del self.openai_vector_stores[vector_store_id] - - # Also delete the underlying vector DB - try: - await self.unregister_vector_db(vector_store_id) - except Exception as e: - logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: {e}") - - return VectorStoreDeleteResponse( - id=vector_store_id, - deleted=True, - ) - - async def openai_search_vector_store( - self, - vector_store_id: str, - query: str | list[str], - filters: dict[str, Any] | None = None, - max_num_results: int = 10, - ranking_options: dict[str, Any] | None = None, - rewrite_query: bool = False, - search_mode: Literal["keyword", "vector", "hybrid"] = "vector", - ) -> VectorStoreSearchResponse: - """Search for chunks in a vector store.""" - if vector_store_id not in self.openai_vector_stores: - raise ValueError(f"Vector store {vector_store_id} not found") - - if isinstance(query, list): - search_query = " ".join(query) - else: - search_query = query - - try: - score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0 - params = { - "max_chunks": max_num_results * CHUNK_MULTIPLIER, - "score_threshold": score_threshold, - "mode": search_mode, - } - # TODO: Add support for ranking_options.ranker - - response = await self.query_chunks( - vector_db_id=vector_store_id, - query=search_query, - params=params, - ) - - # Convert response to OpenAI format - data = [] - for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)): - # Apply score based filtering - if score < score_threshold: - continue - - # Apply filters if provided - if filters: - # Simple metadata filtering - if not self._matches_filters(chunk.metadata, filters): - continue - - chunk_data = { - "id": f"chunk_{i}", - "object": "vector_store.search_result", - "score": score, - "content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content), - "metadata": chunk.metadata, - } - data.append(chunk_data) - if len(data) >= max_num_results: - break - - return VectorStoreSearchResponse( - search_query=search_query, - data=data, - has_more=False, # For simplicity, we don't implement pagination here - next_page=None, - ) - - except Exception as e: - logger.error(f"Error searching vector store {vector_store_id}: {e}") - # Return empty results on error - return VectorStoreSearchResponse( - search_query=search_query, - data=[], - has_more=False, - next_page=None, - ) - - def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool: - """Check if metadata matches the provided 
filters.""" - for key, value in filters.items(): - if key not in metadata: - return False - if metadata[key] != value: - return False - return True - def generate_chunk_id(document_id: str, chunk_text: str) -> str: """Generate a unique chunk ID using a hash of document ID and chunk text.""" diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 50b864fcd..de41f388c 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -189,7 +189,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): async def openai_create_vector_store( self, - name: str | None = None, + name: str, file_ids: list[str] | None = None, expires_after: dict[str, Any] | None = None, chunking_strategy: dict[str, Any] | None = None, diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 9e1ee9f6d..9360ef36a 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -187,7 +187,7 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): async def openai_create_vector_store( self, - name: str | None = None, + name: str, file_ids: list[str] | None = None, expires_after: dict[str, Any] | None = None, chunking_strategy: dict[str, Any] | None = None, diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index a11513174..cff62bff5 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -189,7 +189,7 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): async def openai_create_vector_store( self, - name: str | None = None, + name: str, file_ids: list[str] | None = None, expires_after: dict[str, Any] | None = None, chunking_strategy: dict[str, Any] | None = None, diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py new file mode 100644 index 000000000..345171828 --- /dev/null +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -0,0 +1,354 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import logging +import time +import uuid +from abc import ABC, abstractmethod +from typing import Any, Literal + +from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.apis.vector_io import ( + QueryChunksResponse, + VectorStoreDeleteResponse, + VectorStoreListResponse, + VectorStoreObject, + VectorStoreSearchResponse, +) + +logger = logging.getLogger(__name__) + +# Constants for OpenAI vector stores +CHUNK_MULTIPLIER = 5 + + +class OpenAIVectorStoreMixin(ABC): + """ + Mixin class that provides common OpenAI Vector Store API implementation. + Providers need to implement the abstract storage methods and maintain + an openai_vector_stores in-memory cache. 
+ """ + + # These should be provided by the implementing class + openai_vector_stores: dict[str, dict[str, Any]] + + @abstractmethod + async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Save vector store metadata to persistent storage.""" + pass + + @abstractmethod + async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]: + """Load all vector store metadata from persistent storage.""" + pass + + @abstractmethod + async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None: + """Update vector store metadata in persistent storage.""" + pass + + @abstractmethod + async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None: + """Delete vector store metadata from persistent storage.""" + pass + + @abstractmethod + async def register_vector_db(self, vector_db: VectorDB) -> None: + """Register a vector database (provider-specific implementation).""" + pass + + @abstractmethod + async def unregister_vector_db(self, vector_db_id: str) -> None: + """Unregister a vector database (provider-specific implementation).""" + pass + + @abstractmethod + async def query_chunks( + self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None + ) -> QueryChunksResponse: + """Query chunks from a vector database (provider-specific implementation).""" + pass + + async def openai_create_vector_store( + self, + name: str, + file_ids: list[str] | None = None, + expires_after: dict[str, Any] | None = None, + chunking_strategy: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + embedding_model: str | None = None, + embedding_dimension: int | None = 384, + provider_id: str | None = None, + provider_vector_db_id: str | None = None, + ) -> VectorStoreObject: + """Creates a vector store.""" + print("IN OPENAI VECTOR STORE MIXIN, openai_create_vector_store") + # store and vector_db have the same id + store_id = name or str(uuid.uuid4()) + created_at = int(time.time()) + + if provider_id is None: + raise ValueError("Provider ID is required") + + if embedding_model is None: + raise ValueError("Embedding model is required") + + # Use provided embedding dimension or default to 384 + if embedding_dimension is None: + raise ValueError("Embedding dimension is required") + + provider_vector_db_id = provider_vector_db_id or store_id + vector_db = VectorDB( + identifier=store_id, + embedding_dimension=embedding_dimension, + embedding_model=embedding_model, + provider_id=provider_id, + provider_resource_id=provider_vector_db_id, + ) + from rich.pretty import pprint + + print("VECTOR DB") + pprint(vector_db) + + # Register the vector DB + await self.register_vector_db(vector_db) + + # Create OpenAI vector store metadata + store_info = { + "id": store_id, + "object": "vector_store", + "created_at": created_at, + "name": store_id, + "usage_bytes": 0, + "file_counts": {}, + "status": "completed", + "expires_after": expires_after, + "expires_at": None, + "last_active_at": created_at, + "file_ids": file_ids or [], + "chunking_strategy": chunking_strategy, + } + + # Add provider information to metadata if provided + metadata = metadata or {} + if provider_id: + metadata["provider_id"] = provider_id + if provider_vector_db_id: + metadata["provider_vector_db_id"] = provider_vector_db_id + store_info["metadata"] = metadata + + # Save to persistent storage (provider-specific) + await self._save_openai_vector_store(store_id, store_info) + + # Store in memory cache + 
self.openai_vector_stores[store_id] = store_info + + return VectorStoreObject( + id=store_id, + created_at=created_at, + name=store_id, + usage_bytes=0, + file_counts={}, + status="completed", + expires_after=expires_after, + expires_at=None, + last_active_at=created_at, + metadata=metadata, + ) + + async def openai_list_vector_stores( + self, + limit: int = 20, + order: str = "desc", + after: str | None = None, + before: str | None = None, + ) -> VectorStoreListResponse: + """Returns a list of vector stores.""" + # Get all vector stores + all_stores = list(self.openai_vector_stores.values()) + + # Sort by created_at + reverse_order = order == "desc" + all_stores.sort(key=lambda x: x["created_at"], reverse=reverse_order) + + # Apply cursor-based pagination + if after: + after_index = next((i for i, store in enumerate(all_stores) if store["id"] == after), -1) + if after_index >= 0: + all_stores = all_stores[after_index + 1 :] + + if before: + before_index = next((i for i, store in enumerate(all_stores) if store["id"] == before), len(all_stores)) + all_stores = all_stores[:before_index] + + # Apply limit + limited_stores = all_stores[:limit] + # Convert to VectorStoreObject instances + data = [VectorStoreObject(**store) for store in limited_stores] + + # Determine pagination info + has_more = len(all_stores) > limit + first_id = data[0].id if data else None + last_id = data[-1].id if data else None + + return VectorStoreListResponse( + data=data, + has_more=has_more, + first_id=first_id, + last_id=last_id, + ) + + async def openai_retrieve_vector_store( + self, + vector_store_id: str, + ) -> VectorStoreObject: + """Retrieves a vector store.""" + if vector_store_id not in self.openai_vector_stores: + raise ValueError(f"Vector store {vector_store_id} not found") + + store_info = self.openai_vector_stores[vector_store_id] + return VectorStoreObject(**store_info) + + async def openai_update_vector_store( + self, + vector_store_id: str, + name: str | None = None, + expires_after: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + ) -> VectorStoreObject: + """Modifies a vector store.""" + if vector_store_id not in self.openai_vector_stores: + raise ValueError(f"Vector store {vector_store_id} not found") + + store_info = self.openai_vector_stores[vector_store_id].copy() + + # Update fields if provided + if name is not None: + store_info["name"] = name + if expires_after is not None: + store_info["expires_after"] = expires_after + if metadata is not None: + store_info["metadata"] = metadata + + # Update last_active_at + store_info["last_active_at"] = int(time.time()) + + # Save to persistent storage (provider-specific) + await self._update_openai_vector_store(vector_store_id, store_info) + + # Update in-memory cache + self.openai_vector_stores[vector_store_id] = store_info + + return VectorStoreObject(**store_info) + + async def openai_delete_vector_store( + self, + vector_store_id: str, + ) -> VectorStoreDeleteResponse: + """Delete a vector store.""" + if vector_store_id not in self.openai_vector_stores: + raise ValueError(f"Vector store {vector_store_id} not found") + + # Delete from persistent storage (provider-specific) + await self._delete_openai_vector_store_from_storage(vector_store_id) + + # Delete from in-memory cache + del self.openai_vector_stores[vector_store_id] + + # Also delete the underlying vector DB + try: + await self.unregister_vector_db(vector_store_id) + except Exception as e: + logger.warning(f"Failed to delete underlying vector DB {vector_store_id}: 
{e}") + + return VectorStoreDeleteResponse( + id=vector_store_id, + deleted=True, + ) + + async def openai_search_vector_store( + self, + vector_store_id: str, + query: str | list[str], + filters: dict[str, Any] | None = None, + max_num_results: int = 10, + ranking_options: dict[str, Any] | None = None, + rewrite_query: bool = False, + search_mode: Literal["keyword", "vector", "hybrid"] = "vector", + ) -> VectorStoreSearchResponse: + """Search for chunks in a vector store.""" + if vector_store_id not in self.openai_vector_stores: + raise ValueError(f"Vector store {vector_store_id} not found") + + if isinstance(query, list): + search_query = " ".join(query) + else: + search_query = query + + try: + score_threshold = ranking_options.get("score_threshold", 0.0) if ranking_options else 0.0 + params = { + "max_chunks": max_num_results * CHUNK_MULTIPLIER, + "score_threshold": score_threshold, + "mode": search_mode, + } + # TODO: Add support for ranking_options.ranker + + response = await self.query_chunks( + vector_db_id=vector_store_id, + query=search_query, + params=params, + ) + + # Convert response to OpenAI format + data = [] + for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)): + # Apply score based filtering + if score < score_threshold: + continue + + # Apply filters if provided + if filters: + # Simple metadata filtering + if not self._matches_filters(chunk.metadata, filters): + continue + + chunk_data = { + "id": f"chunk_{i}", + "object": "vector_store.search_result", + "score": score, + "content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content), + "metadata": chunk.metadata, + } + data.append(chunk_data) + if len(data) >= max_num_results: + break + + return VectorStoreSearchResponse( + search_query=search_query, + data=data, + has_more=False, # For simplicity, we don't implement pagination here + next_page=None, + ) + + except Exception as e: + logger.error(f"Error searching vector store {vector_store_id}: {e}") + # Return empty results on error + return VectorStoreSearchResponse( + search_query=search_query, + data=[], + has_more=False, + next_page=None, + ) + + def _matches_filters(self, metadata: dict[str, Any], filters: dict[str, Any]) -> bool: + """Check if metadata matches the provided filters.""" + for key, value in filters.items(): + if key not in metadata: + return False + if metadata[key] != value: + return False + return True From de37a04c3eae88a0750dca98b0d960e3f3e9d022 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Wed, 11 Jun 2025 17:30:34 -0700 Subject: [PATCH 3/9] fix: set appropriate defaults for params (#2434) Setting defaults to be `| None` else they get marked as required params in open-api spec. 
--- docs/_static/llama-stack-spec.html | 8 +++----- docs/_static/llama-stack-spec.yaml | 6 ++---- llama_stack/apis/vector_io/vector_io.py | 8 ++++---- llama_stack/distribution/routers/vector_io.py | 8 ++++---- .../remote/vector_io/chroma/chroma.py | 11 +++++------ .../remote/vector_io/milvus/milvus.py | 11 +++++------ .../remote/vector_io/qdrant/qdrant.py | 11 +++++------ .../utils/memory/openai_vector_store_mixin.py | 19 +++++++++++++------ 8 files changed, 41 insertions(+), 41 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 18f4f49b9..a1a3217c4 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -3318,7 +3318,7 @@ "name": "limit", "in": "query", "description": "A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20.", - "required": true, + "required": false, "schema": { "type": "integer" } @@ -3327,7 +3327,7 @@ "name": "order", "in": "query", "description": "Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order.", - "required": true, + "required": false, "schema": { "type": "string" } @@ -13128,9 +13128,7 @@ }, "additionalProperties": false, "required": [ - "query", - "max_num_results", - "rewrite_query" + "query" ], "title": "OpenaiSearchVectorStoreRequest" }, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 4dbcd3ac9..15593d060 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -2323,7 +2323,7 @@ paths: description: >- A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20. - required: true + required: false schema: type: integer - name: order @@ -2331,7 +2331,7 @@ paths: description: >- Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order. - required: true + required: false schema: type: string - name: after @@ -9189,8 +9189,6 @@ components: additionalProperties: false required: - query - - max_num_results - - rewrite_query title: OpenaiSearchVectorStoreRequest VectorStoreSearchResponse: type: object diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py index 5f54539a4..c14a88c5e 100644 --- a/llama_stack/apis/vector_io/vector_io.py +++ b/llama_stack/apis/vector_io/vector_io.py @@ -193,8 +193,8 @@ class VectorIO(Protocol): @webmethod(route="/openai/v1/vector_stores", method="GET") async def openai_list_vector_stores( self, - limit: int = 20, - order: str = "desc", + limit: int | None = 20, + order: str | None = "desc", after: str | None = None, before: str | None = None, ) -> VectorStoreListResponse: @@ -256,9 +256,9 @@ class VectorIO(Protocol): vector_store_id: str, query: str | list[str], filters: dict[str, Any] | None = None, - max_num_results: int = 10, + max_num_results: int | None = 10, ranking_options: dict[str, Any] | None = None, - rewrite_query: bool = False, + rewrite_query: bool | None = False, ) -> VectorStoreSearchResponse: """Search for chunks in a vector store. 
diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py index 30b19c436..601109963 100644 --- a/llama_stack/distribution/routers/vector_io.py +++ b/llama_stack/distribution/routers/vector_io.py @@ -151,8 +151,8 @@ class VectorIORouter(VectorIO): async def openai_list_vector_stores( self, - limit: int = 20, - order: str = "desc", + limit: int | None = 20, + order: str | None = "desc", after: str | None = None, before: str | None = None, ) -> VectorStoreListResponse: @@ -239,9 +239,9 @@ class VectorIORouter(VectorIO): vector_store_id: str, query: str | list[str], filters: dict[str, Any] | None = None, - max_num_results: int = 10, + max_num_results: int | None = 10, ranking_options: dict[str, Any] | None = None, - rewrite_query: bool = False, + rewrite_query: bool | None = False, ) -> VectorStoreSearchResponse: logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}") # Route based on vector store ID diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index de41f388c..5f5be539d 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -6,7 +6,7 @@ import asyncio import json import logging -from typing import Any, Literal +from typing import Any from urllib.parse import urlparse import chromadb @@ -203,8 +203,8 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): async def openai_list_vector_stores( self, - limit: int = 20, - order: str = "desc", + limit: int | None = 20, + order: str | None = "desc", after: str | None = None, before: str | None = None, ) -> VectorStoreListResponse: @@ -236,9 +236,8 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): vector_store_id: str, query: str | list[str], filters: dict[str, Any] | None = None, - max_num_results: int = 10, + max_num_results: int | None = 10, ranking_options: dict[str, Any] | None = None, - rewrite_query: bool = False, - search_mode: Literal["keyword", "vector", "hybrid"] = "vector", + rewrite_query: bool | None = False, ) -> VectorStoreSearchResponse: raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma") diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index 9360ef36a..ae59af599 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -9,7 +9,7 @@ import hashlib import logging import os import uuid -from typing import Any, Literal +from typing import Any from numpy.typing import NDArray from pymilvus import MilvusClient @@ -201,8 +201,8 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): async def openai_list_vector_stores( self, - limit: int = 20, - order: str = "desc", + limit: int | None = 20, + order: str | None = "desc", after: str | None = None, before: str | None = None, ) -> VectorStoreListResponse: @@ -234,10 +234,9 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): vector_store_id: str, query: str | list[str], filters: dict[str, Any] | None = None, - max_num_results: int = 10, + max_num_results: int | None = 10, ranking_options: dict[str, Any] | None = None, - rewrite_query: bool = False, - search_mode: Literal["keyword", "vector", "hybrid"] = "vector", + rewrite_query: bool | None = False, ) -> VectorStoreSearchResponse: raise NotImplementedError("OpenAI Vector Stores 
API is not supported in Qdrant") diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index cff62bff5..9e802fd6a 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -6,7 +6,7 @@ import logging import uuid -from typing import Any, Literal +from typing import Any from numpy.typing import NDArray from qdrant_client import AsyncQdrantClient, models @@ -203,8 +203,8 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): async def openai_list_vector_stores( self, - limit: int = 20, - order: str = "desc", + limit: int | None = 20, + order: str | None = "desc", after: str | None = None, before: str | None = None, ) -> VectorStoreListResponse: @@ -236,9 +236,8 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): vector_store_id: str, query: str | list[str], filters: dict[str, Any] | None = None, - max_num_results: int = 10, + max_num_results: int | None = 10, ranking_options: dict[str, Any] | None = None, - rewrite_query: bool = False, - search_mode: Literal["keyword", "vector", "hybrid"] = "vector", + rewrite_query: bool | None = False, ) -> VectorStoreSearchResponse: raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 345171828..c2a63f3c5 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -8,7 +8,7 @@ import logging import time import uuid from abc import ABC, abstractmethod -from typing import Any, Literal +from typing import Any from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import ( @@ -161,12 +161,15 @@ class OpenAIVectorStoreMixin(ABC): async def openai_list_vector_stores( self, - limit: int = 20, - order: str = "desc", + limit: int | None = 20, + order: str | None = "desc", after: str | None = None, before: str | None = None, ) -> VectorStoreListResponse: """Returns a list of vector stores.""" + limit = limit or 20 + order = order or "desc" + # Get all vector stores all_stores = list(self.openai_vector_stores.values()) @@ -274,12 +277,16 @@ class OpenAIVectorStoreMixin(ABC): vector_store_id: str, query: str | list[str], filters: dict[str, Any] | None = None, - max_num_results: int = 10, + max_num_results: int | None = 10, ranking_options: dict[str, Any] | None = None, - rewrite_query: bool = False, - search_mode: Literal["keyword", "vector", "hybrid"] = "vector", + rewrite_query: bool | None = False, + # search_mode: Literal["keyword", "vector", "hybrid"] = "vector", ) -> VectorStoreSearchResponse: """Search for chunks in a vector store.""" + # TODO: Add support in the API for this + search_mode = "vector" + max_num_results = max_num_results or 10 + if vector_store_id not in self.openai_vector_stores: raise ValueError(f"Vector store {vector_store_id} not found") From eb04731750688de26f0aba7291199f8c9b1521b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Thu, 12 Jun 2025 16:14:32 +0200 Subject: [PATCH 4/9] ci: fix external provider test (#2438) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? The test wasn't using the correct virtual environment. Also augment the console width for logs. 
Signed-off-by: Sébastien Han --- .github/workflows/test-external-providers.yml | 10 ++++++---- llama_stack/distribution/distribution.py | 1 + llama_stack/log.py | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/.github/workflows/test-external-providers.yml b/.github/workflows/test-external-providers.yml index 06ab7cf3c..cdf18fab7 100644 --- a/.github/workflows/test-external-providers.yml +++ b/.github/workflows/test-external-providers.yml @@ -45,20 +45,22 @@ jobs: - name: Build distro from config file run: | - USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml + USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. llama stack build --config tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml - name: Start Llama Stack server in background if: ${{ matrix.image-type }} == 'venv' env: INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" run: | - uv run pip list - nohup uv run --active llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 & + # Use the virtual environment created by the build step (name comes from build config) + source ci-test/bin/activate + uv pip list + nohup llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 & - name: Wait for Llama Stack server to be ready run: | for i in {1..30}; do - if ! grep -q "remote::custom_ollama from /home/runner/.llama/providers.d/remote/inference/custom_ollama.yaml" server.log; then + if ! grep -q "Successfully loaded external provider remote::custom_ollama" server.log; then echo "Waiting for Llama Stack server to load the provider..." sleep 1 else diff --git a/llama_stack/distribution/distribution.py b/llama_stack/distribution/distribution.py index b860d15ab..e37b2c443 100644 --- a/llama_stack/distribution/distribution.py +++ b/llama_stack/distribution/distribution.py @@ -180,6 +180,7 @@ def get_provider_registry( if provider_type_key in ret[api]: logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}") ret[api][provider_type_key] = spec + logger.info(f"Successfully loaded external provider {provider_type_key}") except yaml.YAMLError as yaml_err: logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}") raise yaml_err diff --git a/llama_stack/log.py b/llama_stack/log.py index f4184710a..c14967f0a 100644 --- a/llama_stack/log.py +++ b/llama_stack/log.py @@ -115,7 +115,7 @@ def parse_environment_config(env_config: str) -> dict[str, int]: class CustomRichHandler(RichHandler): def __init__(self, *args, **kwargs): - kwargs["console"] = Console(width=120) + kwargs["console"] = Console(width=150) super().__init__(*args, **kwargs) def emit(self, record): From 35c2817d0ae94ab8eda837a1f1b4eef0f9a6ae60 Mon Sep 17 00:00:00 2001 From: Ibrahim Haroon <99413953+Ibrahim-Haroon@users.noreply.github.com> Date: Thu, 12 Jun 2025 11:23:59 -0400 Subject: [PATCH 5/9] fix(weaviate): handle case where distance is 0 by setting score to infinity (#2415) # What does this PR do? Fixes provider weaviate `query_vector` function for when the distance between the query embedding and an embedding within the vector db is 0 (identical vectors). Catches `ZeroDivisionError` and then sets `score` to infinity, which represent maximum similarity. 
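For context, a minimal sketch of the distance-to-score conversion this fix implements is shown below. The standalone helper name is illustrative (the actual change is a one-line edit inside the provider's `query_vector` path); identical vectors produce a distance of 0, which is mapped to infinity instead of dividing by zero.

```python
# Illustrative helper; the provider applies this logic inline when building scores.
def distance_to_score(distance: float) -> float:
    """Convert a Weaviate distance into a similarity score.

    A distance of 0 means the stored embedding is identical to the query
    embedding, so return +inf (maximum similarity) rather than dividing by zero.
    """
    return 1.0 / distance if distance != 0 else float("inf")


assert distance_to_score(0.5) == 2.0
assert distance_to_score(0.0) == float("inf")
```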
Closes [#2381] ## Test Plan Checkout this PR Execute this code and there will no longer be a `ZeroDivisionError` exception ``` from llama_stack_client import LlamaStackClient base_url = "http://localhost:8321" client = LlamaStackClient(base_url=base_url) models = client.models.list() embedding_model = ( em := next(m for m in models if m.model_type == "embedding") ).identifier embedding_dimension = 384 _ = client.vector_dbs.register( vector_db_id="foo_db", embedding_model=embedding_model, embedding_dimension=embedding_dimension, provider_id="weaviate", ) chunk = { "content": "foo", "mime_type": "text/plain", "metadata": { "document_id": "foo-id" } } client.vector_io.insert(vector_db_id="foo_db", chunks=[chunk]) client.vector_io.query(vector_db_id="foo_db", query="foo") ``` --- .../remote/vector_io/weaviate/weaviate.py | 2 +- tests/integration/vector_io/test_vector_io.py | 33 +++++++++++++++++++ 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index e6fe8ccd3..6f2027dad 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -76,7 +76,7 @@ class WeaviateIndex(EmbeddingIndex): continue chunks.append(chunk) - scores.append(1.0 / doc.metadata.distance) + scores.append(1.0 / doc.metadata.distance if doc.metadata.distance != 0 else float("inf")) return QueryChunksResponse(chunks=chunks, scores=scores) diff --git a/tests/integration/vector_io/test_vector_io.py b/tests/integration/vector_io/test_vector_io.py index f1cac9701..f550cf666 100644 --- a/tests/integration/vector_io/test_vector_io.py +++ b/tests/integration/vector_io/test_vector_io.py @@ -154,3 +154,36 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e assert len(response.chunks) > 0 assert response.chunks[0].metadata["document_id"] == "doc1" assert response.chunks[0].metadata["source"] == "precomputed" + + +def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(client_with_empty_registry, embedding_model_id): + vector_db_id = "test_precomputed_embeddings_db" + client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_id, + embedding_model=embedding_model_id, + embedding_dimension=384, + ) + + chunks_with_embeddings = [ + Chunk( + content="duplicate", + metadata={"document_id": "doc1", "source": "precomputed"}, + embedding=[0.1] * 384, + ), + ] + + client_with_empty_registry.vector_io.insert( + vector_db_id=vector_db_id, + chunks=chunks_with_embeddings, + ) + + response = client_with_empty_registry.vector_io.query( + vector_db_id=vector_db_id, + query="duplicate", + ) + + # Verify the top result is the expected document + assert response is not None + assert len(response.chunks) > 0 + assert response.chunks[0].metadata["document_id"] == "doc1" + assert response.chunks[0].metadata["source"] == "precomputed" From 0bc1747ed8ac06360e5c9df6086e1c67a7e939bf Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Thu, 12 Jun 2025 15:34:22 -0700 Subject: [PATCH 6/9] feat: update search for vector_stores (#2441) Updated the `search` functionality return response to match openai. 
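A minimal sketch of the new response shape follows. The dataclasses are simplified stand-ins for the pydantic models this patch adds to `llama_stack/apis/vector_io/vector_io.py` (`VectorStoreContent`, `VectorStoreSearchResponse`, `VectorStoreSearchResponsePage`); the identifiers used in the example instance are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class VectorStoreContent:
    type: str  # currently always "text"
    text: str


@dataclass
class VectorStoreSearchResponse:
    # One result per matching chunk, in OpenAI-compatible form.
    file_id: str
    filename: str
    score: float
    content: list[VectorStoreContent]
    attributes: dict | None = None


@dataclass
class VectorStoreSearchResponsePage:
    # The page object now returned by openai_search_vector_store.
    search_query: str
    data: list[VectorStoreSearchResponse]
    object: str = "vector_store.search_results.page"
    has_more: bool = False
    next_page: str | None = None


# Each (chunk, score) pair from the underlying vector DB becomes one result item.
result = VectorStoreSearchResponse(
    file_id="file-123",      # hypothetical values for illustration
    filename="example.md",
    score=0.87,
    attributes={"document_id": "doc1", "topic": "programming"},
    content=[VectorStoreContent(type="text", text="Python is a high-level language.")],
)
page = VectorStoreSearchResponsePage(search_query="python", data=[result])
```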
## Test Plan ``` pytest -sv --stack-config=http://localhost:8321 tests/integration/vector_io/test_openai_vector_stores.py --embedding-model all-MiniLM-L6-v2 ``` --- docs/_static/llama-stack-spec.html | 91 +++++++++++----- docs/_static/llama-stack-spec.yaml | 55 ++++++++-- llama_stack/apis/vector_io/vector_io.py | 23 +++- llama_stack/distribution/routers/vector_io.py | 4 +- .../remote/vector_io/chroma/chroma.py | 4 +- .../remote/vector_io/milvus/milvus.py | 4 +- .../remote/vector_io/qdrant/qdrant.py | 4 +- .../utils/memory/openai_vector_store_mixin.py | 55 +++++++--- .../vector_io/test_openai_vector_stores.py | 102 ++++++++++-------- 9 files changed, 236 insertions(+), 106 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index a1a3217c4..96de04ec9 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -3864,7 +3864,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/VectorStoreSearchResponse" + "$ref": "#/components/schemas/VectorStoreSearchResponsePage" } } } @@ -13132,7 +13132,70 @@ ], "title": "OpenaiSearchVectorStoreRequest" }, + "VectorStoreContent": { + "type": "object", + "properties": { + "type": { + "type": "string", + "const": "text" + }, + "text": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "type", + "text" + ], + "title": "VectorStoreContent" + }, "VectorStoreSearchResponse": { + "type": "object", + "properties": { + "file_id": { + "type": "string" + }, + "filename": { + "type": "string" + }, + "score": { + "type": "number" + }, + "attributes": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "number" + }, + { + "type": "boolean" + } + ] + } + }, + "content": { + "type": "array", + "items": { + "$ref": "#/components/schemas/VectorStoreContent" + } + } + }, + "additionalProperties": false, + "required": [ + "file_id", + "filename", + "score", + "content" + ], + "title": "VectorStoreSearchResponse", + "description": "Response from searching a vector store." + }, + "VectorStoreSearchResponsePage": { "type": "object", "properties": { "object": { @@ -13145,29 +13208,7 @@ "data": { "type": "array", "items": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } + "$ref": "#/components/schemas/VectorStoreSearchResponse" } }, "has_more": { @@ -13185,7 +13226,7 @@ "data", "has_more" ], - "title": "VectorStoreSearchResponse", + "title": "VectorStoreSearchResponsePage", "description": "Response from searching a vector store." 
}, "OpenaiUpdateVectorStoreRequest": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 15593d060..b2fe870be 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -2734,7 +2734,7 @@ paths: content: application/json: schema: - $ref: '#/components/schemas/VectorStoreSearchResponse' + $ref: '#/components/schemas/VectorStoreSearchResponsePage' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -9190,7 +9190,48 @@ components: required: - query title: OpenaiSearchVectorStoreRequest + VectorStoreContent: + type: object + properties: + type: + type: string + const: text + text: + type: string + additionalProperties: false + required: + - type + - text + title: VectorStoreContent VectorStoreSearchResponse: + type: object + properties: + file_id: + type: string + filename: + type: string + score: + type: number + attributes: + type: object + additionalProperties: + oneOf: + - type: string + - type: number + - type: boolean + content: + type: array + items: + $ref: '#/components/schemas/VectorStoreContent' + additionalProperties: false + required: + - file_id + - filename + - score + - content + title: VectorStoreSearchResponse + description: Response from searching a vector store. + VectorStoreSearchResponsePage: type: object properties: object: @@ -9201,15 +9242,7 @@ components: data: type: array items: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object + $ref: '#/components/schemas/VectorStoreSearchResponse' has_more: type: boolean default: false @@ -9221,7 +9254,7 @@ components: - search_query - data - has_more - title: VectorStoreSearchResponse + title: VectorStoreSearchResponsePage description: Response from searching a vector store. OpenaiUpdateVectorStoreRequest: type: object diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py index c14a88c5e..1c8ae4dab 100644 --- a/llama_stack/apis/vector_io/vector_io.py +++ b/llama_stack/apis/vector_io/vector_io.py @@ -8,7 +8,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Protocol, runtime_checkable +from typing import Any, Literal, Protocol, runtime_checkable from pydantic import BaseModel, Field @@ -96,13 +96,30 @@ class VectorStoreSearchRequest(BaseModel): rewrite_query: bool = False +@json_schema_type +class VectorStoreContent(BaseModel): + type: Literal["text"] + text: str + + @json_schema_type class VectorStoreSearchResponse(BaseModel): """Response from searching a vector store.""" + file_id: str + filename: str + score: float + attributes: dict[str, str | float | bool] | None = None + content: list[VectorStoreContent] + + +@json_schema_type +class VectorStoreSearchResponsePage(BaseModel): + """Response from searching a vector store.""" + object: str = "vector_store.search_results.page" search_query: str - data: list[dict[str, Any]] + data: list[VectorStoreSearchResponse] has_more: bool = False next_page: str | None = None @@ -259,7 +276,7 @@ class VectorIO(Protocol): max_num_results: int | None = 10, ranking_options: dict[str, Any] | None = None, rewrite_query: bool | None = False, - ) -> VectorStoreSearchResponse: + ) -> VectorStoreSearchResponsePage: """Search for chunks in a vector store. Searches a vector store for relevant chunks based on a query and optional file attribute filters. 
diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py index 601109963..3d65aef24 100644 --- a/llama_stack/distribution/routers/vector_io.py +++ b/llama_stack/distribution/routers/vector_io.py @@ -17,7 +17,7 @@ from llama_stack.apis.vector_io import ( VectorStoreDeleteResponse, VectorStoreListResponse, VectorStoreObject, - VectorStoreSearchResponse, + VectorStoreSearchResponsePage, ) from llama_stack.log import get_logger from llama_stack.providers.datatypes import RoutingTable @@ -242,7 +242,7 @@ class VectorIORouter(VectorIO): max_num_results: int | None = 10, ranking_options: dict[str, Any] | None = None, rewrite_query: bool | None = False, - ) -> VectorStoreSearchResponse: + ) -> VectorStoreSearchResponsePage: logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}") # Route based on vector store ID provider = self.routing_table.get_provider_impl(vector_store_id) diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 5f5be539d..0d8451eb2 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -21,7 +21,7 @@ from llama_stack.apis.vector_io import ( VectorStoreDeleteResponse, VectorStoreListResponse, VectorStoreObject, - VectorStoreSearchResponse, + VectorStoreSearchResponsePage, ) from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig @@ -239,5 +239,5 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): max_num_results: int | None = 10, ranking_options: dict[str, Any] | None = None, rewrite_query: bool | None = False, - ) -> VectorStoreSearchResponse: + ) -> VectorStoreSearchResponsePage: raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma") diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py index ae59af599..8ae74aedc 100644 --- a/llama_stack/providers/remote/vector_io/milvus/milvus.py +++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py @@ -23,7 +23,7 @@ from llama_stack.apis.vector_io import ( VectorStoreDeleteResponse, VectorStoreListResponse, VectorStoreObject, - VectorStoreSearchResponse, + VectorStoreSearchResponsePage, ) from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig @@ -237,7 +237,7 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): max_num_results: int | None = 10, ranking_options: dict[str, Any] | None = None, rewrite_query: bool | None = False, - ) -> VectorStoreSearchResponse: + ) -> VectorStoreSearchResponsePage: raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 9e802fd6a..10f3b5b0d 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -21,7 +21,7 @@ from llama_stack.apis.vector_io import ( VectorStoreDeleteResponse, VectorStoreListResponse, VectorStoreObject, - VectorStoreSearchResponse, + VectorStoreSearchResponsePage, ) from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate from 
llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig @@ -239,5 +239,5 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): max_num_results: int | None = 10, ranking_options: dict[str, Any] | None = None, rewrite_query: bool | None = False, - ) -> VectorStoreSearchResponse: + ) -> VectorStoreSearchResponsePage: raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant") diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index c2a63f3c5..398075f57 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -13,10 +13,12 @@ from typing import Any from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import ( QueryChunksResponse, + VectorStoreContent, VectorStoreDeleteResponse, VectorStoreListResponse, VectorStoreObject, VectorStoreSearchResponse, + VectorStoreSearchResponsePage, ) logger = logging.getLogger(__name__) @@ -85,7 +87,6 @@ class OpenAIVectorStoreMixin(ABC): provider_vector_db_id: str | None = None, ) -> VectorStoreObject: """Creates a vector store.""" - print("IN OPENAI VECTOR STORE MIXIN, openai_create_vector_store") # store and vector_db have the same id store_id = name or str(uuid.uuid4()) created_at = int(time.time()) @@ -281,7 +282,7 @@ class OpenAIVectorStoreMixin(ABC): ranking_options: dict[str, Any] | None = None, rewrite_query: bool | None = False, # search_mode: Literal["keyword", "vector", "hybrid"] = "vector", - ) -> VectorStoreSearchResponse: + ) -> VectorStoreSearchResponsePage: """Search for chunks in a vector store.""" # TODO: Add support in the API for this search_mode = "vector" @@ -312,7 +313,7 @@ class OpenAIVectorStoreMixin(ABC): # Convert response to OpenAI format data = [] - for i, (chunk, score) in enumerate(zip(response.chunks, response.scores, strict=False)): + for chunk, score in zip(response.chunks, response.scores, strict=False): # Apply score based filtering if score < score_threshold: continue @@ -323,18 +324,46 @@ class OpenAIVectorStoreMixin(ABC): if not self._matches_filters(chunk.metadata, filters): continue - chunk_data = { - "id": f"chunk_{i}", - "object": "vector_store.search_result", - "score": score, - "content": chunk.content.content if hasattr(chunk.content, "content") else str(chunk.content), - "metadata": chunk.metadata, - } - data.append(chunk_data) + # content is InterleavedContent + if isinstance(chunk.content, str): + content = [ + VectorStoreContent( + type="text", + text=chunk.content, + ) + ] + elif isinstance(chunk.content, list): + # TODO: Add support for other types of content + content = [ + VectorStoreContent( + type="text", + text=item.text, + ) + for item in chunk.content + if item.type == "text" + ] + else: + if chunk.content.type != "text": + raise ValueError(f"Unsupported content type: {chunk.content.type}") + content = [ + VectorStoreContent( + type="text", + text=chunk.content.text, + ) + ] + + response_data_item = VectorStoreSearchResponse( + file_id=chunk.metadata.get("file_id", ""), + filename=chunk.metadata.get("filename", ""), + score=score, + attributes=chunk.metadata, + content=content, + ) + data.append(response_data_item) if len(data) >= max_num_results: break - return VectorStoreSearchResponse( + return VectorStoreSearchResponsePage( search_query=search_query, data=data, has_more=False, # For simplicity, we don't 
implement pagination here @@ -344,7 +373,7 @@ class OpenAIVectorStoreMixin(ABC): except Exception as e: logger.error(f"Error searching vector store {vector_store_id}: {e}") # Return empty results on error - return VectorStoreSearchResponse( + return VectorStoreSearchResponsePage( search_query=search_query, data=[], has_more=False, diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index d67c35e69..a67582a07 100644 --- a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -34,6 +34,13 @@ def openai_client(client_with_models): return OpenAI(base_url=base_url, api_key="fake") +@pytest.fixture(params=["openai_client"]) # , "llama_stack_client"]) +def compat_client(request, client_with_models): + if request.param == "openai_client" and isinstance(client_with_models, LlamaStackAsLibraryClient): + pytest.skip("OpenAI client tests not supported with library client") + return request.getfixturevalue(request.param) + + @pytest.fixture(scope="session") def sample_chunks(): return [ @@ -57,29 +64,29 @@ def sample_chunks(): @pytest.fixture(scope="function") -def openai_client_with_empty_stores(openai_client): +def compat_client_with_empty_stores(compat_client): def clear_vector_stores(): # List and delete all existing vector stores try: - response = openai_client.vector_stores.list() + response = compat_client.vector_stores.list() for store in response.data: - openai_client.vector_stores.delete(vector_store_id=store.id) + compat_client.vector_stores.delete(vector_store_id=store.id) except Exception: # If the API is not available or fails, just continue logger.warning("Failed to clear vector stores") pass clear_vector_stores() - yield openai_client + yield compat_client # Clean up after the test clear_vector_stores() -def test_openai_create_vector_store(openai_client_with_empty_stores, client_with_models): +def test_openai_create_vector_store(compat_client_with_empty_stores, client_with_models): """Test creating a vector store using OpenAI API.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - client = openai_client_with_empty_stores + client = compat_client_with_empty_stores # Create a vector store vector_store = client.vector_stores.create( @@ -96,11 +103,11 @@ def test_openai_create_vector_store(openai_client_with_empty_stores, client_with assert hasattr(vector_store, "created_at") -def test_openai_list_vector_stores(openai_client_with_empty_stores, client_with_models): +def test_openai_list_vector_stores(compat_client_with_empty_stores, client_with_models): """Test listing vector stores using OpenAI API.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - client = openai_client_with_empty_stores + client = compat_client_with_empty_stores # Create a few vector stores store1 = client.vector_stores.create(name="store1", metadata={"type": "test"}) @@ -123,11 +130,11 @@ def test_openai_list_vector_stores(openai_client_with_empty_stores, client_with_ assert len(limited_response.data) == 1 -def test_openai_retrieve_vector_store(openai_client_with_empty_stores, client_with_models): +def test_openai_retrieve_vector_store(compat_client_with_empty_stores, client_with_models): """Test retrieving a specific vector store using OpenAI API.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - client = openai_client_with_empty_stores + client = compat_client_with_empty_stores # Create a 
vector store created_store = client.vector_stores.create(name="retrieve_test_store", metadata={"purpose": "retrieval_test"}) @@ -142,11 +149,11 @@ def test_openai_retrieve_vector_store(openai_client_with_empty_stores, client_wi assert retrieved_store.object == "vector_store" -def test_openai_update_vector_store(openai_client_with_empty_stores, client_with_models): +def test_openai_update_vector_store(compat_client_with_empty_stores, client_with_models): """Test modifying a vector store using OpenAI API.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - client = openai_client_with_empty_stores + client = compat_client_with_empty_stores # Create a vector store created_store = client.vector_stores.create(name="original_name", metadata={"version": "1.0"}) @@ -165,11 +172,11 @@ def test_openai_update_vector_store(openai_client_with_empty_stores, client_with assert modified_store.last_active_at > created_store.last_active_at -def test_openai_delete_vector_store(openai_client_with_empty_stores, client_with_models): +def test_openai_delete_vector_store(compat_client_with_empty_stores, client_with_models): """Test deleting a vector store using OpenAI API.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - client = openai_client_with_empty_stores + client = compat_client_with_empty_stores # Create a vector store created_store = client.vector_stores.create(name="delete_test_store", metadata={"purpose": "deletion_test"}) @@ -187,11 +194,11 @@ def test_openai_delete_vector_store(openai_client_with_empty_stores, client_with client.vector_stores.retrieve(vector_store_id=created_store.id) -def test_openai_vector_store_search_empty(openai_client_with_empty_stores, client_with_models): +def test_openai_vector_store_search_empty(compat_client_with_empty_stores, client_with_models): """Test searching an empty vector store using OpenAI API.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - client = openai_client_with_empty_stores + client = compat_client_with_empty_stores # Create a vector store vector_store = client.vector_stores.create(name="search_test_store", metadata={"purpose": "search_testing"}) @@ -208,15 +215,15 @@ def test_openai_vector_store_search_empty(openai_client_with_empty_stores, clien assert search_response.has_more is False -def test_openai_vector_store_with_chunks(openai_client_with_empty_stores, client_with_models, sample_chunks): +def test_openai_vector_store_with_chunks(compat_client_with_empty_stores, client_with_models, sample_chunks): """Test vector store functionality with actual chunks using both OpenAI and native APIs.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - openai_client = openai_client_with_empty_stores + compat_client = compat_client_with_empty_stores llama_client = client_with_models # Create a vector store using OpenAI API - vector_store = openai_client.vector_stores.create(name="chunks_test_store", metadata={"purpose": "chunks_testing"}) + vector_store = compat_client.vector_stores.create(name="chunks_test_store", metadata={"purpose": "chunks_testing"}) # Insert chunks using the native LlamaStack API (since OpenAI API doesn't have direct chunk insertion) llama_client.vector_io.insert( @@ -225,7 +232,7 @@ def test_openai_vector_store_with_chunks(openai_client_with_empty_stores, client ) # Search using OpenAI API - search_response = openai_client.vector_stores.search( + search_response = compat_client.vector_stores.search( vector_store_id=vector_store.id, 
query="What is Python programming language?", max_num_results=3 ) assert search_response is not None @@ -233,18 +240,19 @@ def test_openai_vector_store_with_chunks(openai_client_with_empty_stores, client # The top result should be about Python (doc1) top_result = search_response.data[0] - assert "python" in top_result.content.lower() or "programming" in top_result.content.lower() - assert top_result.metadata["document_id"] == "doc1" + top_content = top_result.content[0].text + assert "python" in top_content.lower() or "programming" in top_content.lower() + assert top_result.attributes["document_id"] == "doc1" # Test filtering by metadata - filtered_search = openai_client.vector_stores.search( + filtered_search = compat_client.vector_stores.search( vector_store_id=vector_store.id, query="artificial intelligence", filters={"topic": "ai"}, max_num_results=5 ) assert filtered_search is not None # All results should have topic "ai" for result in filtered_search.data: - assert result.metadata["topic"] == "ai" + assert result.attributes["topic"] == "ai" @pytest.mark.parametrize( @@ -257,18 +265,18 @@ def test_openai_vector_store_with_chunks(openai_client_with_empty_stores, client ], ) def test_openai_vector_store_search_relevance( - openai_client_with_empty_stores, client_with_models, sample_chunks, test_case + compat_client_with_empty_stores, client_with_models, sample_chunks, test_case ): """Test that OpenAI vector store search returns relevant results for different queries.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - openai_client = openai_client_with_empty_stores + compat_client = compat_client_with_empty_stores llama_client = client_with_models query, expected_doc_id, expected_topic = test_case # Create a vector store - vector_store = openai_client.vector_stores.create( + vector_store = compat_client.vector_stores.create( name=f"relevance_test_{expected_doc_id}", metadata={"purpose": "relevance_testing"} ) @@ -279,7 +287,7 @@ def test_openai_vector_store_search_relevance( ) # Search using OpenAI API - search_response = openai_client.vector_stores.search( + search_response = compat_client.vector_stores.search( vector_store_id=vector_store.id, query=query, max_num_results=4 ) @@ -288,8 +296,9 @@ def test_openai_vector_store_search_relevance( # The top result should match the expected document top_result = search_response.data[0] - assert top_result.metadata["document_id"] == expected_doc_id - assert top_result.metadata["topic"] == expected_topic + + assert top_result.attributes["document_id"] == expected_doc_id + assert top_result.attributes["topic"] == expected_topic # Verify score is included and reasonable assert isinstance(top_result.score, int | float) @@ -297,16 +306,16 @@ def test_openai_vector_store_search_relevance( def test_openai_vector_store_search_with_ranking_options( - openai_client_with_empty_stores, client_with_models, sample_chunks + compat_client_with_empty_stores, client_with_models, sample_chunks ): """Test OpenAI vector store search with ranking options.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - openai_client = openai_client_with_empty_stores + compat_client = compat_client_with_empty_stores llama_client = client_with_models # Create a vector store - vector_store = openai_client.vector_stores.create( + vector_store = compat_client.vector_stores.create( name="ranking_test_store", metadata={"purpose": "ranking_testing"} ) @@ -318,7 +327,7 @@ def test_openai_vector_store_search_with_ranking_options( # 
Search with ranking options threshold = 0.1 - search_response = openai_client.vector_stores.search( + search_response = compat_client.vector_stores.search( vector_store_id=vector_store.id, query="machine learning and artificial intelligence", max_num_results=3, @@ -334,16 +343,16 @@ def test_openai_vector_store_search_with_ranking_options( def test_openai_vector_store_search_with_high_score_filter( - openai_client_with_empty_stores, client_with_models, sample_chunks + compat_client_with_empty_stores, client_with_models, sample_chunks ): """Test that searching with text very similar to a document and high score threshold returns only that document.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - openai_client = openai_client_with_empty_stores + compat_client = compat_client_with_empty_stores llama_client = client_with_models # Create a vector store - vector_store = openai_client.vector_stores.create( + vector_store = compat_client.vector_stores.create( name="high_score_filter_test", metadata={"purpose": "high_score_filtering"} ) @@ -358,7 +367,7 @@ def test_openai_vector_store_search_with_high_score_filter( query = "Python is a high-level programming language with code readability and fewer lines than C++ or Java" # picking up thrshold to be slightly higher than the second result - search_response = openai_client.vector_stores.search( + search_response = compat_client.vector_stores.search( vector_store_id=vector_store.id, query=query, max_num_results=3, @@ -367,7 +376,7 @@ def test_openai_vector_store_search_with_high_score_filter( threshold = search_response.data[1].score + 0.0001 # we expect only one result with the requested threshold - search_response = openai_client.vector_stores.search( + search_response = compat_client.vector_stores.search( vector_store_id=vector_store.id, query=query, max_num_results=10, # Allow more results but expect filtering @@ -379,25 +388,26 @@ def test_openai_vector_store_search_with_high_score_filter( # The top result should be the Python document (doc1) top_result = search_response.data[0] - assert top_result.metadata["document_id"] == "doc1" - assert top_result.metadata["topic"] == "programming" + assert top_result.attributes["document_id"] == "doc1" + assert top_result.attributes["topic"] == "programming" assert top_result.score >= threshold # Verify the content contains Python-related terms - assert "python" in top_result.content.lower() or "programming" in top_result.content.lower() + top_content = top_result.content[0].text + assert "python" in top_content.lower() or "programming" in top_content.lower() def test_openai_vector_store_search_with_max_num_results( - openai_client_with_empty_stores, client_with_models, sample_chunks + compat_client_with_empty_stores, client_with_models, sample_chunks ): """Test OpenAI vector store search with max_num_results.""" skip_if_provider_doesnt_support_openai_vector_stores(client_with_models) - openai_client = openai_client_with_empty_stores + compat_client = compat_client_with_empty_stores llama_client = client_with_models # Create a vector store - vector_store = openai_client.vector_stores.create( + vector_store = compat_client.vector_stores.create( name="max_num_results_test_store", metadata={"purpose": "max_num_results_testing"} ) @@ -408,7 +418,7 @@ def test_openai_vector_store_search_with_max_num_results( ) # Search with max_num_results - search_response = openai_client.vector_stores.search( + search_response = compat_client.vector_stores.search( vector_store_id=vector_store.id, 
query="machine learning and artificial intelligence", max_num_results=2, From fef670b02404606da05f83513ae40c8780f5b544 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Thu, 12 Jun 2025 16:30:23 -0700 Subject: [PATCH 7/9] feat: update openai tests to work with both clients (#2442) https://github.com/meta-llama/llama-stack-client-python/pull/238 updated llama-stack-client to also support Open AI endpoints for embeddings, files, vector-stores. This updates the test to test all configs -- openai sdk, llama stack sdk and library-as-client. --- .../utils/memory/openai_vector_store_mixin.py | 5 -- .../inference/test_openai_embeddings.py | 54 ++++++++++--------- .../vector_io/test_openai_vector_stores.py | 5 +- 3 files changed, 30 insertions(+), 34 deletions(-) diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index 398075f57..7d8163ed2 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -109,11 +109,6 @@ class OpenAIVectorStoreMixin(ABC): provider_id=provider_id, provider_resource_id=provider_vector_db_id, ) - from rich.pretty import pprint - - print("VECTOR DB") - pprint(vector_db) - # Register the vector DB await self.register_vector_db(vector_db) diff --git a/tests/integration/inference/test_openai_embeddings.py b/tests/integration/inference/test_openai_embeddings.py index 759556257..90a91a206 100644 --- a/tests/integration/inference/test_openai_embeddings.py +++ b/tests/integration/inference/test_openai_embeddings.py @@ -34,11 +34,15 @@ def skip_if_model_doesnt_support_variable_dimensions(model_id): pytest.skip("{model_id} does not support variable output embedding dimensions") -def skip_if_model_doesnt_support_openai_embeddings(client_with_models, model_id): - if isinstance(client_with_models, LlamaStackAsLibraryClient): - pytest.skip("OpenAI embeddings are not supported when testing with library client yet.") +@pytest.fixture(params=["openai_client", "llama_stack_client"]) +def compat_client(request, client_with_models): + if request.param == "openai_client" and isinstance(client_with_models, LlamaStackAsLibraryClient): + pytest.skip("OpenAI client tests not supported with library client") + return request.getfixturevalue(request.param) - provider = provider_from_model(client_with_models, model_id) + +def skip_if_model_doesnt_support_openai_embeddings(client, model_id): + provider = provider_from_model(client, model_id) if provider.provider_type in ( "inline::meta-reference", "remote::bedrock", @@ -58,13 +62,13 @@ def openai_client(client_with_models): return OpenAI(base_url=base_url, api_key="fake") -def test_openai_embeddings_single_string(openai_client, client_with_models, embedding_model_id): +def test_openai_embeddings_single_string(compat_client, client_with_models, embedding_model_id): """Test OpenAI embeddings endpoint with a single string input.""" skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) input_text = "Hello, world!" 
- response = openai_client.embeddings.create( + response = compat_client.embeddings.create( model=embedding_model_id, input=input_text, encoding_format="float", @@ -80,13 +84,13 @@ def test_openai_embeddings_single_string(openai_client, client_with_models, embe assert all(isinstance(x, float) for x in response.data[0].embedding) -def test_openai_embeddings_multiple_strings(openai_client, client_with_models, embedding_model_id): +def test_openai_embeddings_multiple_strings(compat_client, client_with_models, embedding_model_id): """Test OpenAI embeddings endpoint with multiple string inputs.""" skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) input_texts = ["Hello, world!", "How are you today?", "This is a test."] - response = openai_client.embeddings.create( + response = compat_client.embeddings.create( model=embedding_model_id, input=input_texts, ) @@ -103,13 +107,13 @@ def test_openai_embeddings_multiple_strings(openai_client, client_with_models, e assert all(isinstance(x, float) for x in embedding_data.embedding) -def test_openai_embeddings_with_encoding_format_float(openai_client, client_with_models, embedding_model_id): +def test_openai_embeddings_with_encoding_format_float(compat_client, client_with_models, embedding_model_id): """Test OpenAI embeddings endpoint with float encoding format.""" skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) input_text = "Test encoding format" - response = openai_client.embeddings.create( + response = compat_client.embeddings.create( model=embedding_model_id, input=input_text, encoding_format="float", @@ -121,7 +125,7 @@ def test_openai_embeddings_with_encoding_format_float(openai_client, client_with assert all(isinstance(x, float) for x in response.data[0].embedding) -def test_openai_embeddings_with_dimensions(openai_client, client_with_models, embedding_model_id): +def test_openai_embeddings_with_dimensions(compat_client, client_with_models, embedding_model_id): """Test OpenAI embeddings endpoint with custom dimensions parameter.""" skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) skip_if_model_doesnt_support_variable_dimensions(embedding_model_id) @@ -129,7 +133,7 @@ def test_openai_embeddings_with_dimensions(openai_client, client_with_models, em input_text = "Test dimensions parameter" dimensions = 16 - response = openai_client.embeddings.create( + response = compat_client.embeddings.create( model=embedding_model_id, input=input_text, dimensions=dimensions, @@ -142,14 +146,14 @@ def test_openai_embeddings_with_dimensions(openai_client, client_with_models, em assert len(response.data[0].embedding) > 0 -def test_openai_embeddings_with_user_parameter(openai_client, client_with_models, embedding_model_id): +def test_openai_embeddings_with_user_parameter(compat_client, client_with_models, embedding_model_id): """Test OpenAI embeddings endpoint with user parameter.""" skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) input_text = "Test user parameter" user_id = "test-user-123" - response = openai_client.embeddings.create( + response = compat_client.embeddings.create( model=embedding_model_id, input=input_text, user=user_id, @@ -161,41 +165,41 @@ def test_openai_embeddings_with_user_parameter(openai_client, client_with_models assert len(response.data[0].embedding) > 0 -def test_openai_embeddings_empty_list_error(openai_client, client_with_models, embedding_model_id): +def 
test_openai_embeddings_empty_list_error(compat_client, client_with_models, embedding_model_id): """Test that empty list input raises an appropriate error.""" skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) with pytest.raises(Exception): # noqa: B017 - openai_client.embeddings.create( + compat_client.embeddings.create( model=embedding_model_id, input=[], ) -def test_openai_embeddings_invalid_model_error(openai_client, client_with_models, embedding_model_id): +def test_openai_embeddings_invalid_model_error(compat_client, client_with_models, embedding_model_id): """Test that invalid model ID raises an appropriate error.""" skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) with pytest.raises(Exception): # noqa: B017 - openai_client.embeddings.create( + compat_client.embeddings.create( model="invalid-model-id", input="Test text", ) -def test_openai_embeddings_different_inputs_different_outputs(openai_client, client_with_models, embedding_model_id): +def test_openai_embeddings_different_inputs_different_outputs(compat_client, client_with_models, embedding_model_id): """Test that different inputs produce different embeddings.""" skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) input_text1 = "This is the first text" input_text2 = "This is completely different content" - response1 = openai_client.embeddings.create( + response1 = compat_client.embeddings.create( model=embedding_model_id, input=input_text1, ) - response2 = openai_client.embeddings.create( + response2 = compat_client.embeddings.create( model=embedding_model_id, input=input_text2, ) @@ -208,7 +212,7 @@ def test_openai_embeddings_different_inputs_different_outputs(openai_client, cli assert embedding1 != embedding2 -def test_openai_embeddings_with_encoding_format_base64(openai_client, client_with_models, embedding_model_id): +def test_openai_embeddings_with_encoding_format_base64(compat_client, client_with_models, embedding_model_id): """Test OpenAI embeddings endpoint with base64 encoding format.""" skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) skip_if_model_doesnt_support_variable_dimensions(embedding_model_id) @@ -216,7 +220,7 @@ def test_openai_embeddings_with_encoding_format_base64(openai_client, client_wit input_text = "Test base64 encoding format" dimensions = 12 - response = openai_client.embeddings.create( + response = compat_client.embeddings.create( model=embedding_model_id, input=input_text, encoding_format="base64", @@ -241,13 +245,13 @@ def test_openai_embeddings_with_encoding_format_base64(openai_client, client_wit assert all(isinstance(x, float) for x in embedding_floats) -def test_openai_embeddings_base64_batch_processing(openai_client, client_with_models, embedding_model_id): +def test_openai_embeddings_base64_batch_processing(compat_client, client_with_models, embedding_model_id): """Test OpenAI embeddings endpoint with base64 encoding for batch processing.""" skip_if_model_doesnt_support_openai_embeddings(client_with_models, embedding_model_id) input_texts = ["First text for base64", "Second text for base64", "Third text for base64"] - response = openai_client.embeddings.create( + response = compat_client.embeddings.create( model=embedding_model_id, input=input_texts, encoding_format="base64", diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py index a67582a07..d9c4199ed 100644 --- 
a/tests/integration/vector_io/test_openai_vector_stores.py +++ b/tests/integration/vector_io/test_openai_vector_stores.py @@ -17,9 +17,6 @@ logger = logging.getLogger(__name__) def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models): - if isinstance(client_with_models, LlamaStackAsLibraryClient): - pytest.skip("OpenAI vector stores are not supported when testing with library client yet.") - vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"] for p in vector_io_providers: if p.provider_type in ["inline::faiss", "inline::sqlite-vec"]: @@ -34,7 +31,7 @@ def openai_client(client_with_models): return OpenAI(base_url=base_url, api_key="fake") -@pytest.fixture(params=["openai_client"]) # , "llama_stack_client"]) +@pytest.fixture(params=["openai_client", "llama_stack_client"]) def compat_client(request, client_with_models): if request.param == "openai_client" and isinstance(client_with_models, LlamaStackAsLibraryClient): pytest.skip("OpenAI client tests not supported with library client") From ddaee42650e394cd6fb906e38d0d284c4cfb9813 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Fri, 13 Jun 2025 01:04:08 -0700 Subject: [PATCH 8/9] test: Update integration-tests.yml (#2443) Added `vector_io` to the CI integration tests. --- .github/workflows/integration-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 7aa8b5807..210a0e95b 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -24,7 +24,7 @@ jobs: matrix: # Listing tests manually since some of them currently fail # TODO: generate matrix list from tests/integration when fixed - test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime] + test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime, vector_io] client-type: [library, http] python-version: ["3.10", "3.11", "3.12"] fail-fast: false # we want to run all tests regardless of failure From e2e15ebb6c271bec5cd03b1f6d7561b992514fc5 Mon Sep 17 00:00:00 2001 From: grs Date: Fri, 13 Jun 2025 04:13:41 -0400 Subject: [PATCH 9/9] feat(auth): allow token to be provided for use against jwks endpoint (#2394) Though the jwks endpoint does not usually require authentication, it does in a kubernetes cluster. While the cluster can be configured to allow anonymous access to that endpoint, this avoids the need to do so. 
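A minimal sketch of the behavior this adds is shown below: when a `token` is configured under `jwks`, it is sent as a Bearer header on the JWKS fetch. The standalone function and its name are illustrative stand-ins for the `OAuth2TokenAuthProvider` logic in the diff, assuming `httpx` as the HTTP client (which the provider already uses).

```python
import httpx


async def fetch_jwks(uri: str, token: str | None = None, tls_cafile: str | None = None) -> dict:
    # Illustrative stand-in: fetch the JWKS document, optionally authorizing with a token.
    headers = {}
    if token:
        headers["Authorization"] = f"Bearer {token}"
    verify = tls_cafile if tls_cafile else True
    async with httpx.AsyncClient(verify=verify) as client:
        res = await client.get(uri, timeout=5, headers=headers)
        res.raise_for_status()
        # Index keys by key id so token signatures can be matched to the right key.
        return {key["kid"]: key for key in res.json()["keys"] if "kid" in key}
```

In a Kubernetes cluster this lets the stack read the cluster's JWKS endpoint with a service account token instead of requiring anonymous access to be enabled.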
--- .github/workflows/integration-auth-tests.yml | 26 +------ docs/source/distributions/configuration.md | 77 ++++++++++--------- .../distribution/server/auth_providers.py | 6 +- tests/unit/server/test_auth.py | 50 ++++++++++++ 4 files changed, 99 insertions(+), 60 deletions(-) diff --git a/.github/workflows/integration-auth-tests.yml b/.github/workflows/integration-auth-tests.yml index a3a746246..e0f3ff2e8 100644 --- a/.github/workflows/integration-auth-tests.yml +++ b/.github/workflows/integration-auth-tests.yml @@ -52,30 +52,7 @@ jobs: run: | kubectl create namespace llama-stack kubectl create serviceaccount llama-stack-auth -n llama-stack - kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token - cat <> $GITHUB_ENV echo "KUBERNETES_ISSUER=$(kubectl get --raw /.well-known/openid-configuration| jq -r .issuer)" >> $GITHUB_ENV echo "KUBERNETES_AUDIENCE=$(kubectl create token llama-stack-auth -n llama-stack --duration=1h | cut -d. -f2 | base64 -d | jq -r '.aud[0]')" >> $GITHUB_ENV + echo "TOKEN=$(cat llama-stack-auth-token)" >> $GITHUB_ENV - name: Set Kube Auth Config and run server env: @@ -101,7 +79,7 @@ jobs: EOF yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml - yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}"}' -i $run_dir/run.yaml + yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}", "token": "${{ env.TOKEN }}"}' -i $run_dir/run.yaml cat $run_dir/run.yaml nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 & diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index de99b6576..a48083055 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -56,10 +56,10 @@ shields: [] server: port: 8321 auth: - provider_type: "kubernetes" + provider_type: "oauth2_token" config: - api_server_url: "https://kubernetes.default.svc" - ca_cert_path: "/path/to/ca.crt" + jwks: + uri: "https://my-token-issuing-svc.com/jwks" ``` Let's break this down into the different sections. The first section specifies the set of APIs that the stack server will serve: @@ -132,16 +132,52 @@ The server supports multiple authentication providers: #### OAuth 2.0/OpenID Connect Provider with Kubernetes -The Kubernetes cluster must be configured to use a service account for authentication. 
+The server can be configured to use service account tokens for authorization, validating these against the Kubernetes API server, e.g.: +```yaml +server: + auth: + provider_type: "oauth2_token" + config: + jwks: + uri: "https://kubernetes.default.svc:8443/openid/v1/jwks" + token: "${env.TOKEN:}" + key_recheck_period: 3600 + tls_cafile: "/path/to/ca.crt" + issuer: "https://kubernetes.default.svc" + audience: "https://kubernetes.default.svc" +``` + +To find your cluster's jwks uri (from which the public key(s) to verify the token signature are obtained), run: +``` +kubectl get --raw /.well-known/openid-configuration| jq -r .jwks_uri +``` + +For the tls_cafile, you can use the CA certificate of the OIDC provider: +```bash +kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}' +``` + +For the issuer, you can use the OIDC provider's URL: +```bash +kubectl get --raw /.well-known/openid-configuration| jq .issuer +``` + +The audience can be obtained from a token, e.g. run: +```bash +kubectl create token default --duration=1h | cut -d. -f2 | base64 -d | jq .aud +``` + +The jwks token is used to authorize access to the jwks endpoint. You can obtain a token by running: ```bash kubectl create namespace llama-stack kubectl create serviceaccount llama-stack-auth -n llama-stack -kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token +export TOKEN=$(cat llama-stack-auth-token) ``` -Make sure the `kube-apiserver` runs with `--anonymous-auth=true` to allow unauthenticated requests +Alternatively, you can configure the jwks endpoint to allow anonymous access. To do this, make sure +the `kube-apiserver` runs with `--anonymous-auth=true` to allow unauthenticated requests and that the correct RoleBinding is created to allow the service account to access the necessary resources. If that is not the case, you can create a RoleBinding for the service account to access the necessary resources: @@ -175,35 +211,6 @@ And then apply the configuration: kubectl apply -f allow-anonymous-openid.yaml ``` -Validates tokens against the Kubernetes API server through the OIDC provider: -```yaml -server: - auth: - provider_type: "oauth2_token" - config: - jwks: - uri: "https://kubernetes.default.svc" - key_recheck_period: 3600 - tls_cafile: "/path/to/ca.crt" - issuer: "https://kubernetes.default.svc" - audience: "https://kubernetes.default.svc" -``` - -To find your cluster's audience, run: -```bash -kubectl create token default --duration=1h | cut -d. 
-f2 | base64 -d | jq .aud -``` - -For the issuer, you can use the OIDC provider's URL: -```bash -kubectl get --raw /.well-known/openid-configuration| jq .issuer -``` - -For the tls_cafile, you can use the CA certificate of the OIDC provider: -```bash -kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}' -``` - The provider extracts user information from the JWT token: - Username from the `sub` claim becomes a role - Kubernetes groups become teams diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index 942ff8a18..98e51c25a 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -84,6 +84,7 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) class OAuth2JWKSConfig(BaseModel): # The JWKS URI for collecting public keys uri: str + token: str | None = Field(default=None, description="token to authorise access to jwks") key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates") @@ -246,9 +247,12 @@ class OAuth2TokenAuthProvider(AuthProvider): if self.config.jwks is None: raise ValueError("JWKS is not configured") if time.time() - self._jwks_at > self.config.jwks.key_recheck_period: + headers = {} + if self.config.jwks.token: + headers["Authorization"] = f"Bearer {self.config.jwks.token}" verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls async with httpx.AsyncClient(verify=verify) as client: - res = await client.get(self.config.jwks.uri, timeout=5) + res = await client.get(self.config.jwks.uri, timeout=5, headers=headers) res.raise_for_status() jwks_data = res.json()["keys"] updated = {} diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index e159aefd1..4410048c5 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -345,6 +345,56 @@ def test_invalid_oauth2_authentication(oauth2_client, invalid_token): assert "Invalid JWT token" in response.json()["error"]["message"] +async def mock_auth_jwks_response(*args, **kwargs): + if "headers" not in kwargs or "Authorization" not in kwargs["headers"]: + return MockResponse(401, {}) + authz = kwargs["headers"]["Authorization"] + if authz != "Bearer my-jwks-token": + return MockResponse(401, {}) + return await mock_jwks_response(args, kwargs) + + +@pytest.fixture +def oauth2_app_with_jwks_token(): + app = FastAPI() + auth_config = AuthenticationConfig( + provider_type=AuthProviderType.OAUTH2_TOKEN, + config={ + "jwks": { + "uri": "http://mock-authz-service/token/introspect", + "key_recheck_period": "3600", + "token": "my-jwks-token", + }, + "audience": "llama-stack", + }, + ) + app.add_middleware(AuthenticationMiddleware, auth_config=auth_config) + + @app.get("/test") + def test_endpoint(): + return {"message": "Authentication successful"} + + return app + + +@pytest.fixture +def oauth2_client_with_jwks_token(oauth2_app_with_jwks_token): + return TestClient(oauth2_app_with_jwks_token) + + +@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response) +def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid): + response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"}) + assert response.status_code == 401 + + +@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response) +def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid): + 
response = oauth2_client_with_jwks_token.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"}) + assert response.status_code == 200 + assert response.json() == {"message": "Authentication successful"} + + def test_get_attributes_from_claims(): claims = { "sub": "my-user",