diff --git a/docs/docs/providers/agents/index.mdx b/docs/docs/providers/agents/index.mdx
index 200a3b9ca..1f7e0c788 100644
--- a/docs/docs/providers/agents/index.mdx
+++ b/docs/docs/providers/agents/index.mdx
@@ -2,7 +2,7 @@
 description: |
   Agents
 
-    APIs for creating and interacting with agentic systems.
+  APIs for creating and interacting with agentic systems.
 sidebar_label: Agents
 title: Agents
 ---
@@ -13,6 +13,6 @@ title: Agents
 
 Agents
 
-  APIs for creating and interacting with agentic systems.
+APIs for creating and interacting with agentic systems.
 
 This section contains documentation for all available providers for the **agents** API.
diff --git a/docs/docs/providers/batches/index.mdx b/docs/docs/providers/batches/index.mdx
index 18fd49945..23b7df14b 100644
--- a/docs/docs/providers/batches/index.mdx
+++ b/docs/docs/providers/batches/index.mdx
@@ -1,15 +1,15 @@
 ---
 description: |
   The Batches API enables efficient processing of multiple requests in a single operation,
-    particularly useful for processing large datasets, batch evaluation workflows, and
-    cost-effective inference at scale.
+  particularly useful for processing large datasets, batch evaluation workflows, and
+  cost-effective inference at scale.
 
-    The API is designed to allow use of openai client libraries for seamless integration.
+  The API is designed to allow use of openai client libraries for seamless integration.
 
-    This API provides the following extensions:
-     - idempotent batch creation
+  This API provides the following extensions:
+   - idempotent batch creation
 
-    Note: This API is currently under active development and may undergo changes.
+  Note: This API is currently under active development and may undergo changes.
 sidebar_label: Batches
 title: Batches
 ---
@@ -19,14 +19,14 @@ title: Batches
 ## Overview
 
 The Batches API enables efficient processing of multiple requests in a single operation,
-  particularly useful for processing large datasets, batch evaluation workflows, and
-  cost-effective inference at scale.
+particularly useful for processing large datasets, batch evaluation workflows, and
+cost-effective inference at scale.
 
-  The API is designed to allow use of openai client libraries for seamless integration.
+The API is designed to allow use of openai client libraries for seamless integration.
 
-  This API provides the following extensions:
-   - idempotent batch creation
+This API provides the following extensions:
+ - idempotent batch creation
 
-  Note: This API is currently under active development and may undergo changes.
+Note: This API is currently under active development and may undergo changes.
 
 This section contains documentation for all available providers for the **batches** API.
diff --git a/docs/docs/providers/eval/index.mdx b/docs/docs/providers/eval/index.mdx
index 3543db246..a6e35d611 100644
--- a/docs/docs/providers/eval/index.mdx
+++ b/docs/docs/providers/eval/index.mdx
@@ -2,7 +2,7 @@
 description: |
   Evaluations
 
-    Llama Stack Evaluation API for running evaluations on model and agent candidates.
+  Llama Stack Evaluation API for running evaluations on model and agent candidates.
 sidebar_label: Eval
 title: Eval
 ---
@@ -13,6 +13,6 @@ title: Eval
 
 Evaluations
 
-  Llama Stack Evaluation API for running evaluations on model and agent candidates.
+Llama Stack Evaluation API for running evaluations on model and agent candidates.
 
 This section contains documentation for all available providers for the **eval** API.
diff --git a/docs/docs/providers/files/index.mdx b/docs/docs/providers/files/index.mdx
index 0b28e9aee..0540c5c3e 100644
--- a/docs/docs/providers/files/index.mdx
+++ b/docs/docs/providers/files/index.mdx
@@ -2,7 +2,7 @@
 description: |
   Files
 
-    This API is used to upload documents that can be used with other Llama Stack APIs.
+  This API is used to upload documents that can be used with other Llama Stack APIs.
 sidebar_label: Files
 title: Files
 ---
@@ -13,6 +13,6 @@ title: Files
 
 Files
 
-  This API is used to upload documents that can be used with other Llama Stack APIs.
+This API is used to upload documents that can be used with other Llama Stack APIs.
 
 This section contains documentation for all available providers for the **files** API.
diff --git a/docs/docs/providers/inference/index.mdx b/docs/docs/providers/inference/index.mdx
index e2d94bfaf..ad050e501 100644
--- a/docs/docs/providers/inference/index.mdx
+++ b/docs/docs/providers/inference/index.mdx
@@ -2,12 +2,12 @@
 description: |
   Inference
 
-    Llama Stack Inference API for generating completions, chat completions, and embeddings.
+  Llama Stack Inference API for generating completions, chat completions, and embeddings.
 
-    This API provides the raw interface to the underlying models. Three kinds of models are supported:
-    - LLM models: these models generate "raw" and "chat" (conversational) completions.
-    - Embedding models: these models generate embeddings to be used for semantic search.
-    - Rerank models: these models reorder the documents based on their relevance to a query.
+  This API provides the raw interface to the underlying models. Three kinds of models are supported:
+  - LLM models: these models generate "raw" and "chat" (conversational) completions.
+  - Embedding models: these models generate embeddings to be used for semantic search.
+  - Rerank models: these models reorder the documents based on their relevance to a query.
 sidebar_label: Inference
 title: Inference
 ---
@@ -18,11 +18,11 @@ title: Inference
 
 Inference
 
-  Llama Stack Inference API for generating completions, chat completions, and embeddings.
+Llama Stack Inference API for generating completions, chat completions, and embeddings.
 
-  This API provides the raw interface to the underlying models. Three kinds of models are supported:
-  - LLM models: these models generate "raw" and "chat" (conversational) completions.
-  - Embedding models: these models generate embeddings to be used for semantic search.
-  - Rerank models: these models reorder the documents based on their relevance to a query.
+This API provides the raw interface to the underlying models. Three kinds of models are supported:
+- LLM models: these models generate "raw" and "chat" (conversational) completions.
+- Embedding models: these models generate embeddings to be used for semantic search.
+- Rerank models: these models reorder the documents based on their relevance to a query.
 
 This section contains documentation for all available providers for the **inference** API.
diff --git a/docs/docs/providers/safety/index.mdx b/docs/docs/providers/safety/index.mdx
index 0c13de28c..e7205f4ad 100644
--- a/docs/docs/providers/safety/index.mdx
+++ b/docs/docs/providers/safety/index.mdx
@@ -2,7 +2,7 @@
 description: |
   Safety
 
-    OpenAI-compatible Moderations API.
+  OpenAI-compatible Moderations API.
 sidebar_label: Safety
 title: Safety
 ---
@@ -13,6 +13,6 @@ title: Safety
 
 Safety
 
-  OpenAI-compatible Moderations API.
+OpenAI-compatible Moderations API.
 
 This section contains documentation for all available providers for the **safety** API.
diff --git a/docs/source/providers/agents/index.md b/docs/source/providers/agents/index.md
new file mode 100644
index 000000000..e69de29bb
diff --git a/docs/source/providers/inference/index.md b/docs/source/providers/inference/index.md
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/llama_stack/providers/remote/vector_io/chroma/chroma.py b/src/llama_stack/providers/remote/vector_io/chroma/chroma.py
index 491db6d4d..ade1b2dc0 100644
--- a/src/llama_stack/providers/remote/vector_io/chroma/chroma.py
+++ b/src/llama_stack/providers/remote/vector_io/chroma/chroma.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 import asyncio
+import heapq
 import json
 from typing import Any
 from urllib.parse import urlparse
@@ -16,6 +17,7 @@ from llama_stack.log import get_logger
 from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion, EmbeddingIndex, VectorStoreWithIndex
+from llama_stack.providers.utils.vector_io.vector_utils import WeightedInMemoryAggregator
 from llama_stack_api import (
     Chunk,
     Files,
@@ -99,8 +101,55 @@ class ChromaIndex(EmbeddingIndex):
     async def delete(self):
         await maybe_await(self.client.delete_collection(self.collection.name))
 
-    async def query_keyword(self, query_string: str, k: int, score_threshold: float) -> QueryChunksResponse:
-        raise NotImplementedError("Keyword search is not supported in Chroma")
+    async def query_keyword(
+        self,
+        query_string: str,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        """
+        Perform keyword search using Chroma's built-in where_document feature.
+
+        Args:
+            query_string: The text query for keyword search
+            k: Number of results to return
+            score_threshold: Minimum similarity score threshold
+
+        Returns:
+            QueryChunksResponse with combined results
+        """
+        try:
+            results = await maybe_await(
+                self.collection.query(
+                    query_texts=[query_string],
+                    where_document={"$contains": query_string},
+                    n_results=k,
+                    include=["documents", "distances"],
+                )
+            )
+        except Exception as e:
+            log.error(f"Chroma client keyword search failed: {e}")
+            raise
+
+        distances = results["distances"][0] if results["distances"] else []
+        documents = results["documents"][0] if results["documents"] else []
+
+        chunks = []
+        scores = []
+
+        for dist, doc in zip(distances, documents, strict=False):
+            doc_data = json.loads(doc)
+            chunk = Chunk(**doc_data)
+
+            score = 1.0 / (1.0 + float(dist)) if dist is not None else 1.0
+
+            if score < score_threshold:
+                continue
+
+            chunks.append(chunk)
+            scores.append(score)
+
+        return QueryChunksResponse(chunks=chunks, scores=scores)
 
     async def delete_chunks(self, chunks_for_deletion: list[ChunkForDeletion]) -> None:
         """Delete a single chunk from the Chroma collection by its ID."""
@@ -116,7 +165,57 @@ class ChromaIndex(EmbeddingIndex):
         reranker_type: str,
         reranker_params: dict[str, Any] | None = None,
     ) -> QueryChunksResponse:
-        raise NotImplementedError("Hybrid search is not supported in Chroma")
+        """
+        Hybrid search combining vector similarity and keyword search using configurable reranking.
+
+        Args:
+            embedding: The query embedding vector
+            query_string: The text query for keyword search
+            k: Number of results to return
+            score_threshold: Minimum similarity score threshold
+            reranker_type: Type of reranker to use ("rrf" or "weighted")
+            reranker_params: Parameters for the reranker
+
+        Returns:
+            QueryChunksResponse with combined results
+        """
+        if reranker_params is None:
+            reranker_params = {}
+
+        # Get results from both search methods
+        vector_response = await self.query_vector(embedding, k, score_threshold)
+        keyword_response = await self.query_keyword(query_string, k, score_threshold)
+
+        # Convert responses to score dictionaries using chunk_id
+        vector_scores = {
+            chunk.chunk_id: score for chunk, score in zip(vector_response.chunks, vector_response.scores, strict=False)
+        }
+        keyword_scores = {
+            chunk.chunk_id: score
+            for chunk, score in zip(keyword_response.chunks, keyword_response.scores, strict=False)
+        }
+
+        # Combine scores using the reranking utility
+        combined_scores = WeightedInMemoryAggregator.combine_search_results(
+            vector_scores, keyword_scores, reranker_type, reranker_params
+        )
+
+        # Efficient top-k selection because it only tracks the k best candidates it's seen so far
+        top_k_items = heapq.nlargest(k, combined_scores.items(), key=lambda x: x[1])
+
+        # Filter by score threshold
+        filtered_items = [(doc_id, score) for doc_id, score in top_k_items if score >= score_threshold]
+
+        # Create a map of chunk_id to chunk for both responses
+        chunk_map = {c.chunk_id: c for c in vector_response.chunks + keyword_response.chunks}
+
+        # Use the map to look up chunks by their IDs
+        chunks = []
+        scores = []
+        for doc_id, score in filtered_items:
+            if doc_id in chunk_map:
+                chunks.append(chunk_map[doc_id])
+                scores.append(score)
+
+        return QueryChunksResponse(chunks=chunks, scores=scores)
 
 
 class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
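The hybrid path above reduces each backend response to a `{chunk_id: score}` dict, merges the two with `WeightedInMemoryAggregator.combine_search_results`, and keeps the best k via `heapq.nlargest`. As a rough, standalone sketch of what the `"rrf"` reranker mode does with those dicts (assuming it follows standard Reciprocal Rank Fusion; `rrf_combine`, `impact_factor`, and the chunk ids below are illustrative, not part of this diff):

```python
import heapq

def rrf_combine(
    vector_scores: dict[str, float],
    keyword_scores: dict[str, float],
    impact_factor: float = 60.0,
) -> dict[str, float]:
    """Reciprocal Rank Fusion over two ranked result sets (stand-in for the real aggregator)."""
    combined: dict[str, float] = {}
    for scores in (vector_scores, keyword_scores):
        # Rank documents within each result set; a higher score means a better rank
        ranked = sorted(scores, key=scores.get, reverse=True)
        for rank, doc_id in enumerate(ranked, start=1):
            combined[doc_id] = combined.get(doc_id, 0.0) + 1.0 / (impact_factor + rank)
    return combined

vector_scores = {"chunk-a": 0.92, "chunk-b": 0.55, "chunk-c": 0.40}
keyword_scores = {"chunk-b": 0.80, "chunk-d": 0.70}

combined = rrf_combine(vector_scores, keyword_scores)
top_k = heapq.nlargest(2, combined.items(), key=lambda item: item[1])
print(top_k)  # chunk-b ranks first because it appears in both result sets
```

Documents that show up in both result sets accumulate contributions from both rankings, which is why RRF tends to favor them even when neither individual score is the highest.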
diff --git a/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
index bbfd60e25..dcf1286c0 100644
--- a/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
+++ b/src/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
@@ -92,6 +92,13 @@ class OpenAIVectorStoreMixin(ABC):
         self.kvstore = kvstore
         self._last_file_batch_cleanup_time = 0
         self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
+        self._vector_store_locks: dict[str, asyncio.Lock] = {}
+
+    def _get_vector_store_lock(self, vector_store_id: str) -> asyncio.Lock:
+        """Get or create a lock for a specific vector store."""
+        if vector_store_id not in self._vector_store_locks:
+            self._vector_store_locks[vector_store_id] = asyncio.Lock()
+        return self._vector_store_locks[vector_store_id]
 
     async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
         """Save vector store metadata to persistent storage."""
@@ -831,16 +838,18 @@ class OpenAIVectorStoreMixin(ABC):
         await self._save_openai_vector_store_file(vector_store_id, file_id, file_info, dict_chunks)
 
         # Update file_ids and file_counts in vector store metadata
-        store_info = self.openai_vector_stores[vector_store_id].copy()
-        store_info["file_ids"].append(file_id)
-        store_info["file_counts"]["total"] += 1
-        store_info["file_counts"][vector_store_file_object.status] += 1
+        # Use lock to prevent race condition when multiple files are attached concurrently
+        async with self._get_vector_store_lock(vector_store_id):
+            store_info = self.openai_vector_stores[vector_store_id].copy()
+            # Deep copy file_counts to avoid mutating shared dict
+            store_info["file_counts"] = store_info["file_counts"].copy()
+            store_info["file_ids"] = store_info["file_ids"].copy()
+            store_info["file_ids"].append(file_id)
+            store_info["file_counts"]["total"] += 1
+            store_info["file_counts"][vector_store_file_object.status] += 1
 
-        # Save updated vector store to persistent storage
-        await self._save_openai_vector_store(vector_store_id, store_info)
-
-        # Update vector store in-memory cache
-        self.openai_vector_stores[vector_store_id] = store_info
+            # Save updated vector store to persistent storage
+            await self._save_openai_vector_store(vector_store_id, store_info)
 
         return vector_store_file_object
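The per-store lock exists because attaching several files to the same vector store concurrently used to read, mutate, and write back the shared `store_info` dict, so later writers could overwrite earlier ones. A toy, self-contained sketch of the copy-then-swap pattern the lock protects (`StoreRegistry` and `attach_file` are illustrative stand-ins, not part of the mixin, and the real code hands persistence and cache updates to `_save_openai_vector_store`):

```python
import asyncio

class StoreRegistry:
    """Toy stand-in for the mixin's in-memory vector store cache."""

    def __init__(self) -> None:
        self.stores = {"vs_1": {"file_ids": [], "file_counts": {"total": 0, "completed": 0}}}
        self._locks: dict[str, asyncio.Lock] = {}

    def _lock_for(self, store_id: str) -> asyncio.Lock:
        # Same lazy-creation pattern as _get_vector_store_lock in the diff
        if store_id not in self._locks:
            self._locks[store_id] = asyncio.Lock()
        return self._locks[store_id]

    async def attach_file(self, store_id: str, file_id: str) -> None:
        async with self._lock_for(store_id):
            info = self.stores[store_id].copy()
            info["file_ids"] = info["file_ids"] + [file_id]   # copy before mutating
            info["file_counts"] = {**info["file_counts"]}
            info["file_counts"]["total"] += 1
            info["file_counts"]["completed"] += 1
            await asyncio.sleep(0)  # stand-in for the await on persistent storage
            self.stores[store_id] = info

async def main() -> None:
    registry = StoreRegistry()
    await asyncio.gather(*(registry.attach_file("vs_1", f"file_{i}") for i in range(5)))
    print(registry.stores["vs_1"]["file_counts"])  # {'total': 5, 'completed': 5}, no lost updates

asyncio.run(main())
```

Without the lock, every task could copy the same baseline between the read and the awaited save, and the final counts would reflect only the last writer.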
diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py
index 102f3f00c..35d5ca2c8 100644
--- a/tests/integration/vector_io/test_openai_vector_stores.py
+++ b/tests/integration/vector_io/test_openai_vector_stores.py
@@ -61,6 +61,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
             "remote::milvus",
             "remote::pgvector",
             "remote::weaviate",
+            "remote::chromadb",
         ],
         "hybrid": [
             "inline::milvus",
@@ -68,6 +69,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_mode
             "remote::milvus",
             "remote::pgvector",
             "remote::weaviate",
+            "remote::chromadb",
         ],
     }
     supported_providers = search_mode_support.get(search_mode, [])
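With `remote::chromadb` added to both lists, the keyword and hybrid suites now run against Chroma, and what survives a search there depends on the distance-to-score conversion in the new `query_keyword`. A minimal sketch of that conversion (the sample distances and the 0.3 threshold are made-up values for illustration):

```python
# Chroma reports distances where lower means closer; the adapter maps them to
# similarity scores in (0, 1] via 1 / (1 + distance) and drops anything below the threshold.
distances = [0.0, 0.25, 1.0, 4.0]
score_threshold = 0.3

scores = [1.0 / (1.0 + d) for d in distances]
kept = [s for s in scores if s >= score_threshold]

print(scores)  # [1.0, 0.8, 0.5, 0.2]
print(kept)    # [1.0, 0.8, 0.5]; the distance-4.0 result falls under the 0.3 threshold
```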