diff --git a/.github/ISSUE_TEMPLATE/tech-debt.yml b/.github/ISSUE_TEMPLATE/tech-debt.yml
new file mode 100644
index 000000000..b281b3482
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/tech-debt.yml
@@ -0,0 +1,30 @@
+name: 🔧 Tech Debt
+description: Something that is functional but should be improved or optimizied
+labels: ["tech-debt"]
+body:
+- type: textarea
+ id: tech-debt-explanation
+ attributes:
+ label: 🤔 What is the technical debt you think should be addressed?
+ description: >
+ A clear and concise description of _what_ needs to be addressed - ensure you are describing
+ constitutes [technical debt](https://en.wikipedia.org/wiki/Technical_debt) and is not a bug
+ or feature request.
+ validations:
+ required: true
+
+- type: textarea
+ id: tech-debt-motivation
+ attributes:
+ label: 💡 What is the benefit of addressing this technical debt?
+ description: >
+ A clear and concise description of _why_ this work is needed.
+ validations:
+ required: true
+
+- type: textarea
+ id: other-thoughts
+ attributes:
+ label: Other thoughts
+ description: >
+ Any thoughts about how this may result in complexity in the codebase, or other trade-offs.
diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index c46100c38..1a8d6734f 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -89,7 +89,7 @@ jobs:
-k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
--text-model="ollama/llama3.2:3b-instruct-fp16" \
--embedding-model=all-MiniLM-L6-v2 \
- --safety-shield=ollama \
+ --safety-shield=$SAFETY_MODEL \
--color=yes \
--capture=tee-sys | tee pytest-${{ matrix.test-type }}.log
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
deleted file mode 100644
index 79c935005..000000000
--- a/.github/workflows/tests.yml
+++ /dev/null
@@ -1,69 +0,0 @@
-name: auto-tests
-
-on:
- # pull_request:
- workflow_dispatch:
- inputs:
- commit_sha:
- description: 'Specific Commit SHA to trigger on'
- required: false
- default: $GITHUB_SHA # default to the last commit of $GITHUB_REF branch
-
-jobs:
- test-llama-stack-as-library:
- runs-on: ubuntu-latest
- env:
- TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
- FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
- TAVILY_SEARCH_API_KEY: ${{ secrets.TAVILY_SEARCH_API_KEY }}
- strategy:
- matrix:
- provider: [fireworks, together]
- steps:
- - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- with:
- ref: ${{ github.event.inputs.commit_sha }}
-
- - name: Echo commit SHA
- run: |
- echo "Triggered on commit SHA: ${{ github.event.inputs.commit_sha }}"
- git rev-parse HEAD
-
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- pip install -r requirements.txt pytest
- pip install -e .
-
- - name: Build providers
- run: |
- llama stack build --template ${{ matrix.provider }} --image-type venv
-
- - name: Install the latest llama-stack-client & llama-models packages
- run: |
- pip install -e git+https://github.com/meta-llama/llama-stack-client-python.git#egg=llama-stack-client
- pip install -e git+https://github.com/meta-llama/llama-models.git#egg=llama-models
-
- - name: Run client-sdk test
- working-directory: "${{ github.workspace }}"
- env:
- REPORT_OUTPUT: md_report.md
- shell: bash
- run: |
- pip install --upgrade pytest-md-report
- echo "REPORT_FILE=${REPORT_OUTPUT}" >> "$GITHUB_ENV"
-
- export INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
- LLAMA_STACK_CONFIG=./llama_stack/templates/${{ matrix.provider }}/run.yaml pytest --md-report --md-report-verbose=1 ./tests/client-sdk/inference/ --md-report-output "$REPORT_OUTPUT"
-
- - name: Output reports to the job summary
- if: always()
- shell: bash
- run: |
- if [ -f "$REPORT_FILE" ]; then
- echo " Test Report for ${{ matrix.provider }}
" >> $GITHUB_STEP_SUMMARY
- echo "" >> $GITHUB_STEP_SUMMARY
- cat "$REPORT_FILE" >> $GITHUB_STEP_SUMMARY
- echo "" >> $GITHUB_STEP_SUMMARY
- echo " " >> $GITHUB_STEP_SUMMARY
- fi
diff --git a/docs/source/getting_started/detailed_tutorial.md b/docs/source/getting_started/detailed_tutorial.md
index 97e7df774..fc59022f9 100644
--- a/docs/source/getting_started/detailed_tutorial.md
+++ b/docs/source/getting_started/detailed_tutorial.md
@@ -77,7 +77,7 @@ ENABLE_OLLAMA=ollama INFERENCE_MODEL="llama3.2:3b" llama stack build --template
You can use a container image to run the Llama Stack server. We provide several container images for the server
component that works with different inference providers out of the box. For this guide, we will use
`llamastack/distribution-starter` as the container image. If you'd like to build your own image or customize the
-configurations, please check out [this guide](../references/index.md).
+configurations, please check out [this guide](../distributions/building_distro.md).
First lets setup some environment variables and create a local directory to mount into the container’s file system.
```bash
export INFERENCE_MODEL="llama3.2:3b"
diff --git a/docs/source/providers/vector_io/remote_milvus.md b/docs/source/providers/vector_io/remote_milvus.md
index f3089e615..6734d8315 100644
--- a/docs/source/providers/vector_io/remote_milvus.md
+++ b/docs/source/providers/vector_io/remote_milvus.md
@@ -114,7 +114,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
| `uri` | `` | No | PydanticUndefined | The URI of the Milvus server |
| `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server |
| `consistency_level` | `` | No | Strong | The consistency level of the Milvus server |
-| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) |
+| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend |
| `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. |
> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.
@@ -124,6 +124,9 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
```yaml
uri: ${env.MILVUS_ENDPOINT}
token: ${env.MILVUS_TOKEN}
+kvstore:
+ type: sqlite
+ db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/milvus_remote_registry.db
```
diff --git a/docs/source/providers/vector_io/remote_pgvector.md b/docs/source/providers/vector_io/remote_pgvector.md
index 685b98f37..3e7d6e776 100644
--- a/docs/source/providers/vector_io/remote_pgvector.md
+++ b/docs/source/providers/vector_io/remote_pgvector.md
@@ -40,6 +40,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
| `db` | `str \| None` | No | postgres | |
| `user` | `str \| None` | No | postgres | |
| `password` | `str \| None` | No | mysecretpassword | |
+| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) |
## Sample Configuration
@@ -49,6 +50,9 @@ port: ${env.PGVECTOR_PORT:=5432}
db: ${env.PGVECTOR_DB}
user: ${env.PGVECTOR_USER}
password: ${env.PGVECTOR_PASSWORD}
+kvstore:
+ type: sqlite
+ db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/pgvector_registry.db
```
diff --git a/docs/source/providers/vector_io/remote_weaviate.md b/docs/source/providers/vector_io/remote_weaviate.md
index b7f811c35..d930515d5 100644
--- a/docs/source/providers/vector_io/remote_weaviate.md
+++ b/docs/source/providers/vector_io/remote_weaviate.md
@@ -36,7 +36,9 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
## Sample Configuration
```yaml
-{}
+kvstore:
+ type: sqlite
+ db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/weaviate_registry.db
```
diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py
index 0306d9156..2a1370c56 100644
--- a/llama_stack/providers/inline/vector_io/faiss/faiss.py
+++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py
@@ -181,8 +181,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
)
self.cache[vector_db.identifier] = index
- # Load existing OpenAI vector stores using the mixin method
- self.openai_vector_stores = await self._load_openai_vector_stores()
+ # Load existing OpenAI vector stores into the in-memory cache
+ await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
# Cleanup if needed
@@ -261,42 +261,6 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
return await index.query_chunks(query, params)
- # OpenAI Vector Store Mixin abstract method implementations
- async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
- """Save vector store metadata to kvstore."""
- assert self.kvstore is not None
- key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
- await self.kvstore.set(key=key, value=json.dumps(store_info))
- self.openai_vector_stores[store_id] = store_info
-
- async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
- """Load all vector store metadata from kvstore."""
- assert self.kvstore is not None
- start_key = OPENAI_VECTOR_STORES_PREFIX
- end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
- stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
-
- stores = {}
- for store_data in stored_openai_stores:
- store_info = json.loads(store_data)
- stores[store_info["id"]] = store_info
- return stores
-
- async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
- """Update vector store metadata in kvstore."""
- assert self.kvstore is not None
- key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
- await self.kvstore.set(key=key, value=json.dumps(store_info))
- self.openai_vector_stores[store_id] = store_info
-
- async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
- """Delete vector store metadata from kvstore."""
- assert self.kvstore is not None
- key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
- await self.kvstore.delete(key)
- if store_id in self.openai_vector_stores:
- del self.openai_vector_stores[store_id]
-
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
index 6acd85c56..771ffa607 100644
--- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
+++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py
@@ -452,8 +452,8 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
- # load any existing OpenAI vector stores
- self.openai_vector_stores = await self._load_openai_vector_stores()
+ # Load existing OpenAI vector stores into the in-memory cache
+ await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
# nothing to do since we don't maintain a persistent connection
@@ -501,41 +501,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
- # OpenAI Vector Store Mixin abstract method implementations
- async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
- """Save vector store metadata to SQLite database."""
- assert self.kvstore is not None
- key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
- await self.kvstore.set(key=key, value=json.dumps(store_info))
- self.openai_vector_stores[store_id] = store_info
-
- async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
- """Load all vector store metadata from SQLite database."""
- assert self.kvstore is not None
- start_key = OPENAI_VECTOR_STORES_PREFIX
- end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
- stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
- stores = {}
- for store_data in stored_openai_stores:
- store_info = json.loads(store_data)
- stores[store_info["id"]] = store_info
- return stores
-
- async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
- """Update vector store metadata in SQLite database."""
- assert self.kvstore is not None
- key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
- await self.kvstore.set(key=key, value=json.dumps(store_info))
- self.openai_vector_stores[store_id] = store_info
-
- async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
- """Delete vector store metadata from SQLite database."""
- assert self.kvstore is not None
- key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
- await self.kvstore.delete(key)
- if store_id in self.openai_vector_stores:
- del self.openai_vector_stores[store_id]
-
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
diff --git a/llama_stack/providers/remote/inference/ollama/models.py b/llama_stack/providers/remote/inference/ollama/models.py
index 64ddb23d9..7c0a19a1a 100644
--- a/llama_stack/providers/remote/inference/ollama/models.py
+++ b/llama_stack/providers/remote/inference/ollama/models.py
@@ -12,6 +12,19 @@ from llama_stack.providers.utils.inference.model_registry import (
build_model_entry,
)
+SAFETY_MODELS_ENTRIES = [
+ # The Llama Guard models don't have their full fp16 versions
+ # so we are going to alias their default version to the canonical SKU
+ build_hf_repo_model_entry(
+ "llama-guard3:8b",
+ CoreModelId.llama_guard_3_8b.value,
+ ),
+ build_hf_repo_model_entry(
+ "llama-guard3:1b",
+ CoreModelId.llama_guard_3_1b.value,
+ ),
+]
+
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"llama3.1:8b-instruct-fp16",
@@ -73,16 +86,6 @@ MODEL_ENTRIES = [
"llama3.3:70b",
CoreModelId.llama3_3_70b_instruct.value,
),
- # The Llama Guard models don't have their full fp16 versions
- # so we are going to alias their default version to the canonical SKU
- build_hf_repo_model_entry(
- "llama-guard3:8b",
- CoreModelId.llama_guard_3_8b.value,
- ),
- build_hf_repo_model_entry(
- "llama-guard3:1b",
- CoreModelId.llama_guard_3_1b.value,
- ),
ProviderModelEntry(
provider_model_id="all-minilm:l6-v2",
aliases=["all-minilm"],
@@ -100,4 +103,4 @@ MODEL_ENTRIES = [
"context_length": 8192,
},
),
-]
+] + SAFETY_MODELS_ENTRIES
diff --git a/llama_stack/providers/remote/vector_io/milvus/config.py b/llama_stack/providers/remote/vector_io/milvus/config.py
index e3f51b4f4..899d3678d 100644
--- a/llama_stack/providers/remote/vector_io/milvus/config.py
+++ b/llama_stack/providers/remote/vector_io/milvus/config.py
@@ -8,7 +8,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field
-from llama_stack.providers.utils.kvstore.config import KVStoreConfig
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.schema_utils import json_schema_type
@@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
uri: str = Field(description="The URI of the Milvus server")
token: str | None = Field(description="The token of the Milvus server")
consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
- kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
+ kvstore: KVStoreConfig = Field(description="Config for KV store backend")
# This configuration allows additional fields to be passed through to the underlying Milvus client.
# See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
@@ -25,4 +25,11 @@ class MilvusVectorIOConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
- return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}
+ return {
+ "uri": "${env.MILVUS_ENDPOINT}",
+ "token": "${env.MILVUS_TOKEN}",
+ "kvstore": SqliteKVStoreConfig.sample_run_config(
+ __distro_dir__=__distro_dir__,
+ db_name="milvus_remote_registry.db",
+ ),
+ }
diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py
index a06130fd0..f301942cb 100644
--- a/llama_stack/providers/remote/vector_io/milvus/milvus.py
+++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -12,7 +12,7 @@ import re
from typing import Any
from numpy.typing import NDArray
-from pymilvus import DataType, MilvusClient
+from pymilvus import DataType, Function, FunctionType, MilvusClient
from llama_stack.apis.files.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
@@ -74,12 +74,66 @@ class MilvusIndex(EmbeddingIndex):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
+
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
+ logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
+ # Create schema for vector search
+ schema = self.client.create_schema()
+ schema.add_field(
+ field_name="chunk_id",
+ datatype=DataType.VARCHAR,
+ is_primary=True,
+ max_length=100,
+ )
+ schema.add_field(
+ field_name="content",
+ datatype=DataType.VARCHAR,
+ max_length=65535,
+ enable_analyzer=True, # Enable text analysis for BM25
+ )
+ schema.add_field(
+ field_name="vector",
+ datatype=DataType.FLOAT_VECTOR,
+ dim=len(embeddings[0]),
+ )
+ schema.add_field(
+ field_name="chunk_content",
+ datatype=DataType.JSON,
+ )
+ # Add sparse vector field for BM25 (required by the function)
+ schema.add_field(
+ field_name="sparse",
+ datatype=DataType.SPARSE_FLOAT_VECTOR,
+ )
+
+ # Create indexes
+ index_params = self.client.prepare_index_params()
+ index_params.add_index(
+ field_name="vector",
+ index_type="FLAT",
+ metric_type="COSINE",
+ )
+ # Add index for sparse field (required by BM25 function)
+ index_params.add_index(
+ field_name="sparse",
+ index_type="SPARSE_INVERTED_INDEX",
+ metric_type="BM25",
+ )
+
+ # Add BM25 function for full-text search
+ bm25_function = Function(
+ name="text_bm25_emb",
+ input_field_names=["content"],
+ output_field_names=["sparse"],
+ function_type=FunctionType.BM25,
+ )
+ schema.add_function(bm25_function)
+
await asyncio.to_thread(
self.client.create_collection,
self.collection_name,
- dimension=len(embeddings[0]),
- auto_id=True,
+ schema=schema,
+ index_params=index_params,
consistency_level=self.consistency_level,
)
@@ -88,8 +142,10 @@ class MilvusIndex(EmbeddingIndex):
data.append(
{
"chunk_id": chunk.chunk_id,
+ "content": chunk.content,
"vector": embedding,
"chunk_content": chunk.model_dump(),
+ # sparse field will be handled by BM25 function automatically
}
)
try:
@@ -107,6 +163,7 @@ class MilvusIndex(EmbeddingIndex):
self.client.search,
collection_name=self.collection_name,
data=[embedding],
+ anns_field="vector",
limit=k,
output_fields=["*"],
search_params={"params": {"radius": score_threshold}},
@@ -121,7 +178,64 @@ class MilvusIndex(EmbeddingIndex):
k: int,
score_threshold: float,
) -> QueryChunksResponse:
- raise NotImplementedError("Keyword search is not supported in Milvus")
+ """
+ Perform BM25-based keyword search using Milvus's built-in full-text search.
+ """
+ try:
+ # Use Milvus's built-in BM25 search
+ search_res = await asyncio.to_thread(
+ self.client.search,
+ collection_name=self.collection_name,
+ data=[query_string], # Raw text query
+ anns_field="sparse", # Use sparse field for BM25
+ output_fields=["chunk_content"], # Output the chunk content
+ limit=k,
+ search_params={
+ "params": {
+ "drop_ratio_search": 0.2, # Ignore low-importance terms
+ }
+ },
+ )
+
+ chunks = []
+ scores = []
+ for res in search_res[0]:
+ chunk = Chunk(**res["entity"]["chunk_content"])
+ chunks.append(chunk)
+ scores.append(res["distance"]) # BM25 score from Milvus
+
+ # Filter by score threshold
+ filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
+ filtered_scores = [score for score in scores if score >= score_threshold]
+
+ return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
+
+ except Exception as e:
+ logger.error(f"Error performing BM25 search: {e}")
+ # Fallback to simple text search
+ return await self._fallback_keyword_search(query_string, k, score_threshold)
+
+ async def _fallback_keyword_search(
+ self,
+ query_string: str,
+ k: int,
+ score_threshold: float,
+ ) -> QueryChunksResponse:
+ """
+ Fallback to simple text search when BM25 search is not available.
+ """
+ # Simple text search using content field
+ search_res = await asyncio.to_thread(
+ self.client.query,
+ collection_name=self.collection_name,
+ filter='content like "%{content}%"',
+ filter_params={"content": query_string},
+ output_fields=["*"],
+ limit=k,
+ )
+ chunks = [Chunk(**res["chunk_content"]) for res in search_res]
+ scores = [1.0] * len(chunks) # Simple binary score for text search
+ return QueryChunksResponse(chunks=chunks, scores=scores)
async def query_hybrid(
self,
@@ -179,7 +293,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
uri = os.path.expanduser(self.config.db_path)
self.client = MilvusClient(uri=uri)
- self.openai_vector_stores = await self._load_openai_vector_stores()
+ # Load existing OpenAI vector stores into the in-memory cache
+ await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
self.client.close()
@@ -246,38 +361,16 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
+ if params and params.get("mode") == "keyword":
+ # Check if this is inline Milvus (Milvus-Lite)
+ if hasattr(self.config, "db_path"):
+ raise NotImplementedError(
+ "Keyword search is not supported in Milvus-Lite. "
+ "Please use a remote Milvus server for keyword search functionality."
+ )
+
return await index.query_chunks(query, params)
- async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
- """Save vector store metadata to persistent storage."""
- assert self.kvstore is not None
- key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
- await self.kvstore.set(key=key, value=json.dumps(store_info))
- self.openai_vector_stores[store_id] = store_info
-
- async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
- """Update vector store metadata in persistent storage."""
- assert self.kvstore is not None
- key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
- await self.kvstore.set(key=key, value=json.dumps(store_info))
- self.openai_vector_stores[store_id] = store_info
-
- async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
- """Delete vector store metadata from persistent storage."""
- assert self.kvstore is not None
- key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
- await self.kvstore.delete(key)
- if store_id in self.openai_vector_stores:
- del self.openai_vector_stores[store_id]
-
- async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
- """Load all vector store metadata from persistent storage."""
- assert self.kvstore is not None
- start_key = OPENAI_VECTOR_STORES_PREFIX
- end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
- stored = await self.kvstore.values_in_range(start_key, end_key)
- return {json.loads(s)["id"]: json.loads(s) for s in stored}
-
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
diff --git a/llama_stack/providers/remote/vector_io/pgvector/config.py b/llama_stack/providers/remote/vector_io/pgvector/config.py
index 92908aa8a..334cbe5be 100644
--- a/llama_stack/providers/remote/vector_io/pgvector/config.py
+++ b/llama_stack/providers/remote/vector_io/pgvector/config.py
@@ -8,6 +8,10 @@ from typing import Any
from pydantic import BaseModel, Field
+from llama_stack.providers.utils.kvstore.config import (
+ KVStoreConfig,
+ SqliteKVStoreConfig,
+)
from llama_stack.schema_utils import json_schema_type
@@ -18,10 +22,12 @@ class PGVectorVectorIOConfig(BaseModel):
db: str | None = Field(default="postgres")
user: str | None = Field(default="postgres")
password: str | None = Field(default="mysecretpassword")
+ kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
@classmethod
def sample_run_config(
cls,
+ __distro_dir__: str,
host: str = "${env.PGVECTOR_HOST:=localhost}",
port: int = "${env.PGVECTOR_PORT:=5432}",
db: str = "${env.PGVECTOR_DB}",
@@ -29,4 +35,14 @@ class PGVectorVectorIOConfig(BaseModel):
password: str = "${env.PGVECTOR_PASSWORD}",
**kwargs: Any,
) -> dict[str, Any]:
- return {"host": host, "port": port, "db": db, "user": user, "password": password}
+ return {
+ "host": host,
+ "port": port,
+ "db": db,
+ "user": user,
+ "password": password,
+ "kvstore": SqliteKVStoreConfig.sample_run_config(
+ __distro_dir__=__distro_dir__,
+ db_name="pgvector_registry.db",
+ ),
+ }
diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
index c3cdef9b8..1bf3eedf8 100644
--- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
+++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py
@@ -13,24 +13,18 @@ from psycopg2 import sql
from psycopg2.extras import Json, execute_values
from pydantic import BaseModel, TypeAdapter
+from llama_stack.apis.files.files import Files
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
QueryChunksResponse,
- SearchRankingOptions,
VectorIO,
- VectorStoreChunkingStrategy,
- VectorStoreDeleteResponse,
- VectorStoreFileContentsResponse,
- VectorStoreFileObject,
- VectorStoreFileStatus,
- VectorStoreListFilesResponse,
- VectorStoreListResponse,
- VectorStoreObject,
- VectorStoreSearchResponsePage,
)
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
@@ -40,6 +34,13 @@ from .config import PGVectorVectorIOConfig
log = logging.getLogger(__name__)
+VERSION = "v3"
+VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
+VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::"
+OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:pgvector:{VERSION}::"
+
def check_extension_version(cur):
cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
@@ -69,7 +70,7 @@ def load_models(cur, cls):
class PGVectorIndex(EmbeddingIndex):
- def __init__(self, vector_db: VectorDB, dimension: int, conn):
+ def __init__(self, vector_db: VectorDB, dimension: int, conn, kvstore: KVStore | None = None):
self.conn = conn
with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
# Sanitize the table name by replacing hyphens with underscores
@@ -77,6 +78,7 @@ class PGVectorIndex(EmbeddingIndex):
# when created with patterns like "test-vector-db-{uuid4()}"
sanitized_identifier = vector_db.identifier.replace("-", "_")
self.table_name = f"vector_store_{sanitized_identifier}"
+ self.kvstore = kvstore
cur.execute(
f"""
@@ -158,15 +160,28 @@ class PGVectorIndex(EmbeddingIndex):
cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
-class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
- def __init__(self, config: PGVectorVectorIOConfig, inference_api: Api.inference) -> None:
+class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
+ def __init__(
+ self,
+ config: PGVectorVectorIOConfig,
+ inference_api: Api.inference,
+ files_api: Files | None = None,
+ ) -> None:
self.config = config
self.inference_api = inference_api
self.conn = None
self.cache = {}
+ self.files_api = files_api
+ self.kvstore: KVStore | None = None
+ self.vector_db_store = None
+ self.openai_vector_store: dict[str, dict[str, Any]] = {}
+ self.metadatadata_collection_name = "openai_vector_stores_metadata"
async def initialize(self) -> None:
log.info(f"Initializing PGVector memory adapter with config: {self.config}")
+ self.kvstore = await kvstore_impl(self.config.kvstore)
+ await self.initialize_openai_vector_stores()
+
try:
self.conn = psycopg2.connect(
host=self.config.host,
@@ -201,14 +216,31 @@ class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
log.info("Connection to PGVector database server closed")
async def register_vector_db(self, vector_db: VectorDB) -> None:
+ # Persist vector DB metadata in the KV store
+ assert self.kvstore is not None
+ key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
+ await self.kvstore.set(key=key, value=vector_db.model_dump_json())
+
+ # Upsert model metadata in Postgres
upsert_models(self.conn, [(vector_db.identifier, vector_db)])
- index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
- self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
+ # Create and cache the PGVector index table for the vector DB
+ index = VectorDBWithIndex(
+ vector_db,
+ index=PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn, kvstore=self.kvstore),
+ inference_api=self.inference_api,
+ )
+ self.cache[vector_db.identifier] = index
async def unregister_vector_db(self, vector_db_id: str) -> None:
- await self.cache[vector_db_id].index.delete()
- del self.cache[vector_db_id]
+ # Remove provider index and cache
+ if vector_db_id in self.cache:
+ await self.cache[vector_db_id].index.delete()
+ del self.cache[vector_db_id]
+
+ # Delete vector DB metadata from KV store
+ assert self.kvstore is not None
+ await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")
async def insert_chunks(
self,
@@ -237,107 +269,20 @@ class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
return self.cache[vector_db_id]
- async def openai_create_vector_store(
- self,
- name: str,
- file_ids: list[str] | None = None,
- expires_after: dict[str, Any] | None = None,
- chunking_strategy: dict[str, Any] | None = None,
- metadata: dict[str, Any] | None = None,
- embedding_model: str | None = None,
- embedding_dimension: int | None = 384,
- provider_id: str | None = None,
- provider_vector_db_id: str | None = None,
- ) -> VectorStoreObject:
+ # OpenAI Vector Stores File operations are not supported in PGVector
+ async def _save_openai_vector_store_file(
+ self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
+ ) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
- async def openai_list_vector_stores(
- self,
- limit: int | None = 20,
- order: str | None = "desc",
- after: str | None = None,
- before: str | None = None,
- ) -> VectorStoreListResponse:
+ async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
- async def openai_retrieve_vector_store(
- self,
- vector_store_id: str,
- ) -> VectorStoreObject:
+ async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
- async def openai_update_vector_store(
- self,
- vector_store_id: str,
- name: str | None = None,
- expires_after: dict[str, Any] | None = None,
- metadata: dict[str, Any] | None = None,
- ) -> VectorStoreObject:
+ async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
- async def openai_delete_vector_store(
- self,
- vector_store_id: str,
- ) -> VectorStoreDeleteResponse:
- raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
- async def openai_search_vector_store(
- self,
- vector_store_id: str,
- query: str | list[str],
- filters: dict[str, Any] | None = None,
- max_num_results: int | None = 10,
- ranking_options: SearchRankingOptions | None = None,
- rewrite_query: bool | None = False,
- search_mode: str | None = "vector",
- ) -> VectorStoreSearchResponsePage:
- raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
- async def openai_attach_file_to_vector_store(
- self,
- vector_store_id: str,
- file_id: str,
- attributes: dict[str, Any] | None = None,
- chunking_strategy: VectorStoreChunkingStrategy | None = None,
- ) -> VectorStoreFileObject:
- raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
- async def openai_list_files_in_vector_store(
- self,
- vector_store_id: str,
- limit: int | None = 20,
- order: str | None = "desc",
- after: str | None = None,
- before: str | None = None,
- filter: VectorStoreFileStatus | None = None,
- ) -> VectorStoreListFilesResponse:
- raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
- async def openai_retrieve_vector_store_file(
- self,
- vector_store_id: str,
- file_id: str,
- ) -> VectorStoreFileObject:
- raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
- async def openai_retrieve_vector_store_file_contents(
- self,
- vector_store_id: str,
- file_id: str,
- ) -> VectorStoreFileContentsResponse:
- raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
- async def openai_update_vector_store_file(
- self,
- vector_store_id: str,
- file_id: str,
- attributes: dict[str, Any] | None = None,
- ) -> VectorStoreFileObject:
- raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
- async def openai_delete_vector_store_file(
- self,
- vector_store_id: str,
- file_id: str,
- ) -> VectorStoreFileObject:
+ async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
diff --git a/llama_stack/providers/remote/vector_io/weaviate/config.py b/llama_stack/providers/remote/vector_io/weaviate/config.py
index a8c6e3e2c..4283b8d3b 100644
--- a/llama_stack/providers/remote/vector_io/weaviate/config.py
+++ b/llama_stack/providers/remote/vector_io/weaviate/config.py
@@ -6,15 +6,26 @@
from typing import Any
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+
+from llama_stack.providers.utils.kvstore.config import (
+ KVStoreConfig,
+ SqliteKVStoreConfig,
+)
class WeaviateRequestProviderData(BaseModel):
weaviate_api_key: str
weaviate_cluster_url: str
+ kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
class WeaviateVectorIOConfig(BaseModel):
@classmethod
- def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
- return {}
+ def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
+ return {
+ "kvstore": SqliteKVStoreConfig.sample_run_config(
+ __distro_dir__=__distro_dir__,
+ db_name="weaviate_registry.db",
+ ),
+ }
diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
index c63dd70c6..35bb40454 100644
--- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
+++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
@@ -14,10 +14,13 @@ from weaviate.classes.init import Auth
from weaviate.classes.query import Filter
from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.files.files import Files
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
@@ -27,11 +30,19 @@ from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig
log = logging.getLogger(__name__)
+VERSION = "v3"
+VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
+VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::"
+OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:weaviate:{VERSION}::"
+
class WeaviateIndex(EmbeddingIndex):
- def __init__(self, client: weaviate.Client, collection_name: str):
+ def __init__(self, client: weaviate.Client, collection_name: str, kvstore: KVStore | None = None):
self.client = client
self.collection_name = collection_name
+ self.kvstore = kvstore
async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), (
@@ -109,11 +120,21 @@ class WeaviateVectorIOAdapter(
NeedsRequestProviderData,
VectorDBsProtocolPrivate,
):
- def __init__(self, config: WeaviateVectorIOConfig, inference_api: Api.inference) -> None:
+ def __init__(
+ self,
+ config: WeaviateVectorIOConfig,
+ inference_api: Api.inference,
+ files_api: Files | None,
+ ) -> None:
self.config = config
self.inference_api = inference_api
self.client_cache = {}
self.cache = {}
+ self.files_api = files_api
+ self.kvstore: KVStore | None = None
+ self.vector_db_store = None
+ self.openai_vector_stores: dict[str, dict[str, Any]] = {}
+ self.metadata_collection_name = "openai_vector_stores_metadata"
def _get_client(self) -> weaviate.Client:
provider_data = self.get_request_provider_data()
@@ -132,7 +153,26 @@ class WeaviateVectorIOAdapter(
return client
async def initialize(self) -> None:
- pass
+ """Set up KV store and load existing vector DBs and OpenAI vector stores."""
+ # Initialize KV store for metadata
+ self.kvstore = await kvstore_impl(self.config.kvstore)
+
+ # Load existing vector DB definitions
+ start_key = VECTOR_DBS_PREFIX
+ end_key = f"{VECTOR_DBS_PREFIX}\xff"
+ stored = await self.kvstore.values_in_range(start_key, end_key)
+ for raw in stored:
+ vector_db = VectorDB.model_validate_json(raw)
+ client = self._get_client()
+ idx = WeaviateIndex(client=client, collection_name=vector_db.identifier, kvstore=self.kvstore)
+ self.cache[vector_db.identifier] = VectorDBWithIndex(
+ vector_db=vector_db,
+ index=idx,
+ inference_api=self.inference_api,
+ )
+
+ # Load OpenAI vector stores metadata into cache
+ await self.initialize_openai_vector_stores()
async def shutdown(self) -> None:
for client in self.client_cache.values():
@@ -206,3 +246,21 @@ class WeaviateVectorIOAdapter(
raise ValueError(f"Vector DB {vector_db_id} not found")
return await index.query_chunks(query, params)
+
+ # OpenAI Vector Stores File operations are not supported in Weaviate
+ async def _save_openai_vector_store_file(
+ self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
+ ) -> None:
+ raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+
+ async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
+ raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+
+ async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
+ raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+
+ async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
+ raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+
+ async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
+ raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
index 7c97ff7f6..27bb1c997 100644
--- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
+++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py
@@ -5,6 +5,7 @@
# the root directory of this source tree.
import asyncio
+import json
import logging
import mimetypes
import time
@@ -35,6 +36,7 @@ from llama_stack.apis.vector_io import (
VectorStoreSearchResponse,
VectorStoreSearchResponsePage,
)
+from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks
logger = logging.getLogger(__name__)
@@ -59,26 +61,45 @@ class OpenAIVectorStoreMixin(ABC):
# These should be provided by the implementing class
openai_vector_stores: dict[str, dict[str, Any]]
files_api: Files | None
+ # KV store for persisting OpenAI vector store metadata
+ kvstore: KVStore | None
- @abstractmethod
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to persistent storage."""
- pass
+ assert self.kvstore is not None
+ key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+ await self.kvstore.set(key=key, value=json.dumps(store_info))
+ # update in-memory cache
+ self.openai_vector_stores[store_id] = store_info
- @abstractmethod
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from persistent storage."""
- pass
+ assert self.kvstore is not None
+ start_key = OPENAI_VECTOR_STORES_PREFIX
+ end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
+ stored_data = await self.kvstore.values_in_range(start_key, end_key)
+
+ stores: dict[str, dict[str, Any]] = {}
+ for item in stored_data:
+ info = json.loads(item)
+ stores[info["id"]] = info
+ return stores
- @abstractmethod
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in persistent storage."""
- pass
+ assert self.kvstore is not None
+ key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+ await self.kvstore.set(key=key, value=json.dumps(store_info))
+ # update in-memory cache
+ self.openai_vector_stores[store_id] = store_info
- @abstractmethod
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from persistent storage."""
- pass
+ assert self.kvstore is not None
+ key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+ await self.kvstore.delete(key)
+ # remove from in-memory cache
+ self.openai_vector_stores.pop(store_id, None)
@abstractmethod
async def _save_openai_vector_store_file(
@@ -117,6 +138,10 @@ class OpenAIVectorStoreMixin(ABC):
"""Unregister a vector database (provider-specific implementation)."""
pass
+ async def initialize_openai_vector_stores(self) -> None:
+ """Load existing OpenAI vector stores into the in-memory cache."""
+ self.openai_vector_stores = await self._load_openai_vector_stores()
+
@abstractmethod
async def insert_chunks(
self,
diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py
index 4eccfb25c..e5c13aa74 100644
--- a/llama_stack/templates/nvidia/nvidia.py
+++ b/llama_stack/templates/nvidia/nvidia.py
@@ -68,7 +68,7 @@ def get_distribution_template() -> DistributionTemplate:
),
]
- default_models = get_model_registry(available_models)
+ default_models, _ = get_model_registry(available_models)
return DistributionTemplate(
name="nvidia",
distro_type="self_hosted",
diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py
index 942905dae..63a27e07f 100644
--- a/llama_stack/templates/open-benchmark/open_benchmark.py
+++ b/llama_stack/templates/open-benchmark/open_benchmark.py
@@ -128,6 +128,7 @@ def get_distribution_template() -> DistributionTemplate:
provider_id="${env.ENABLE_PGVECTOR:+pgvector}",
provider_type="remote::pgvector",
config=PGVectorVectorIOConfig.sample_run_config(
+ f"~/.llama/distributions/{name}",
db="${env.PGVECTOR_DB:=}",
user="${env.PGVECTOR_USER:=}",
password="${env.PGVECTOR_PASSWORD:=}",
@@ -146,7 +147,8 @@ def get_distribution_template() -> DistributionTemplate:
),
]
- default_models = get_model_registry(available_models) + [
+ models, _ = get_model_registry(available_models)
+ default_models = models + [
ModelInput(
model_id="meta-llama/Llama-3.3-70B-Instruct",
provider_id="groq",
diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml
index 0b368ebc9..7d07cc4bf 100644
--- a/llama_stack/templates/open-benchmark/run.yaml
+++ b/llama_stack/templates/open-benchmark/run.yaml
@@ -54,6 +54,9 @@ providers:
db: ${env.PGVECTOR_DB:=}
user: ${env.PGVECTOR_USER:=}
password: ${env.PGVECTOR_PASSWORD:=}
+ kvstore:
+ type: sqlite
+ db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/pgvector_registry.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml
index 888a2c3bf..8e20f5224 100644
--- a/llama_stack/templates/starter/run.yaml
+++ b/llama_stack/templates/starter/run.yaml
@@ -166,6 +166,9 @@ providers:
db: ${env.PGVECTOR_DB:=}
user: ${env.PGVECTOR_USER:=}
password: ${env.PGVECTOR_PASSWORD:=}
+ kvstore:
+ type: sqlite
+ db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
@@ -1171,24 +1174,8 @@ models:
provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}
model_type: embedding
shields:
-- shield_id: ${env.ENABLE_OLLAMA:=__disabled__}
- provider_id: llama-guard
- provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=llama-guard3:1b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
- provider_id: llama-guard
- provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-8b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
- provider_id: llama-guard
- provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-11b-vision}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
- provider_id: llama-guard
- provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-8B}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
- provider_id: llama-guard
- provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-11B-Vision-Turbo}
-- shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
- provider_id: llama-guard
- provider_shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/${env.SAFETY_MODEL:=sambanova/Meta-Llama-Guard-3-8B}
+- shield_id: ${env.SAFETY_MODEL:=__disabled__}
+ provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}
vector_dbs: []
datasets: []
scoring_fns: []
diff --git a/llama_stack/templates/starter/starter.py b/llama_stack/templates/starter/starter.py
index 6b8aa8974..f6ca73028 100644
--- a/llama_stack/templates/starter/starter.py
+++ b/llama_stack/templates/starter/starter.py
@@ -12,7 +12,6 @@ from llama_stack.distribution.datatypes import (
ModelInput,
Provider,
ProviderSpec,
- ShieldInput,
ToolGroupInput,
)
from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -32,75 +31,39 @@ from llama_stack.providers.registry.inference import available_providers
from llama_stack.providers.remote.inference.anthropic.models import (
MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
)
-from llama_stack.providers.remote.inference.anthropic.models import (
- SAFETY_MODELS_ENTRIES as ANTHROPIC_SAFETY_MODELS_ENTRIES,
-)
from llama_stack.providers.remote.inference.bedrock.models import (
MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
)
-from llama_stack.providers.remote.inference.bedrock.models import (
- SAFETY_MODELS_ENTRIES as BEDROCK_SAFETY_MODELS_ENTRIES,
-)
from llama_stack.providers.remote.inference.cerebras.models import (
MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
)
-from llama_stack.providers.remote.inference.cerebras.models import (
- SAFETY_MODELS_ENTRIES as CEREBRAS_SAFETY_MODELS_ENTRIES,
-)
from llama_stack.providers.remote.inference.databricks.databricks import (
MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
)
-from llama_stack.providers.remote.inference.databricks.databricks import (
- SAFETY_MODELS_ENTRIES as DATABRICKS_SAFETY_MODELS_ENTRIES,
-)
from llama_stack.providers.remote.inference.fireworks.models import (
MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
)
-from llama_stack.providers.remote.inference.fireworks.models import (
- SAFETY_MODELS_ENTRIES as FIREWORKS_SAFETY_MODELS_ENTRIES,
-)
from llama_stack.providers.remote.inference.gemini.models import (
MODEL_ENTRIES as GEMINI_MODEL_ENTRIES,
)
-from llama_stack.providers.remote.inference.gemini.models import (
- SAFETY_MODELS_ENTRIES as GEMINI_SAFETY_MODELS_ENTRIES,
-)
from llama_stack.providers.remote.inference.groq.models import (
MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
)
-from llama_stack.providers.remote.inference.groq.models import (
- SAFETY_MODELS_ENTRIES as GROQ_SAFETY_MODELS_ENTRIES,
-)
from llama_stack.providers.remote.inference.nvidia.models import (
MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
)
-from llama_stack.providers.remote.inference.nvidia.models import (
- SAFETY_MODELS_ENTRIES as NVIDIA_SAFETY_MODELS_ENTRIES,
-)
from llama_stack.providers.remote.inference.openai.models import (
MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
)
-from llama_stack.providers.remote.inference.openai.models import (
- SAFETY_MODELS_ENTRIES as OPENAI_SAFETY_MODELS_ENTRIES,
-)
from llama_stack.providers.remote.inference.runpod.runpod import (
MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
)
-from llama_stack.providers.remote.inference.runpod.runpod import (
- SAFETY_MODELS_ENTRIES as RUNPOD_SAFETY_MODELS_ENTRIES,
-)
from llama_stack.providers.remote.inference.sambanova.models import (
MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
)
-from llama_stack.providers.remote.inference.sambanova.models import (
- SAFETY_MODELS_ENTRIES as SAMBANOVA_SAFETY_MODELS_ENTRIES,
-)
from llama_stack.providers.remote.inference.together.models import (
MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
)
-from llama_stack.providers.remote.inference.together.models import (
- SAFETY_MODELS_ENTRIES as TOGETHER_SAFETY_MODELS_ENTRIES,
-)
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.config import (
PGVectorVectorIOConfig,
@@ -111,6 +74,7 @@ from llama_stack.templates.template import (
DistributionTemplate,
RunConfigSettings,
get_model_registry,
+ get_shield_registry,
)
@@ -164,28 +128,13 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
"""Get model entries for a specific provider type."""
safety_model_entries_map = {
- "openai": OPENAI_SAFETY_MODELS_ENTRIES,
- "fireworks": FIREWORKS_SAFETY_MODELS_ENTRIES,
- "together": TOGETHER_SAFETY_MODELS_ENTRIES,
- "anthropic": ANTHROPIC_SAFETY_MODELS_ENTRIES,
- "gemini": GEMINI_SAFETY_MODELS_ENTRIES,
- "groq": GROQ_SAFETY_MODELS_ENTRIES,
- "sambanova": SAMBANOVA_SAFETY_MODELS_ENTRIES,
- "cerebras": CEREBRAS_SAFETY_MODELS_ENTRIES,
- "bedrock": BEDROCK_SAFETY_MODELS_ENTRIES,
- "databricks": DATABRICKS_SAFETY_MODELS_ENTRIES,
- "nvidia": NVIDIA_SAFETY_MODELS_ENTRIES,
- "runpod": RUNPOD_SAFETY_MODELS_ENTRIES,
- }
-
- # Special handling for providers with dynamic model entries
- if provider_type == "ollama":
- return [
+ "ollama": [
ProviderModelEntry(
- provider_model_id="llama-guard3:1b",
+ provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
model_type=ModelType.llm,
),
- ]
+ ],
+ }
return safety_model_entries_map.get(provider_type, [])
@@ -246,28 +195,20 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
# build a list of shields for all possible providers
-def get_shields_for_providers(providers: list[Provider]) -> list[ShieldInput]:
- shields = []
+def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]:
+ available_models = {}
for provider in providers:
provider_type = provider.provider_type.split("::")[1]
safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
if len(safety_model_entries) == 0:
continue
- if provider.provider_id:
- shield_id = provider.provider_id
- else:
- raise ValueError(f"Provider {provider.provider_type} has no provider_id")
- for safety_model_entry in safety_model_entries:
- print(f"provider.provider_id: {provider.provider_id}")
- print(f"safety_model_entry.provider_model_id: {safety_model_entry.provider_model_id}")
- shields.append(
- ShieldInput(
- provider_id="llama-guard",
- shield_id=shield_id,
- provider_shield_id=f"{provider.provider_id}/${{env.SAFETY_MODEL:={safety_model_entry.provider_model_id}}}",
- )
- )
- return shields
+
+ env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
+ provider_id = f"${{env.{env_var}:=__disabled__}}"
+
+ available_models[provider_id] = safety_model_entries
+
+ return available_models
def get_distribution_template() -> DistributionTemplate:
@@ -300,6 +241,7 @@ def get_distribution_template() -> DistributionTemplate:
provider_id="${env.ENABLE_PGVECTOR:=__disabled__}",
provider_type="remote::pgvector",
config=PGVectorVectorIOConfig.sample_run_config(
+ f"~/.llama/distributions/{name}",
db="${env.PGVECTOR_DB:=}",
user="${env.PGVECTOR_USER:=}",
password="${env.PGVECTOR_PASSWORD:=}",
@@ -307,8 +249,6 @@ def get_distribution_template() -> DistributionTemplate:
),
]
- shields = get_shields_for_providers(remote_inference_providers)
-
providers = {
"inference": ([p.provider_type for p in remote_inference_providers] + ["inline::sentence-transformers"]),
"vector_io": ([p.provider_type for p in vector_io_providers]),
@@ -361,7 +301,10 @@ def get_distribution_template() -> DistributionTemplate:
},
)
- default_models = get_model_registry(available_models)
+ default_models, ids_conflict_in_models = get_model_registry(available_models)
+
+ available_safety_models = get_safety_models_for_providers(remote_inference_providers)
+ shields = get_shield_registry(available_safety_models, ids_conflict_in_models)
return DistributionTemplate(
name=name,
diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py
index dceb13c8b..fb2528873 100644
--- a/llama_stack/templates/template.py
+++ b/llama_stack/templates/template.py
@@ -37,7 +37,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as ge
def get_model_registry(
available_models: dict[str, list[ProviderModelEntry]],
-) -> list[ModelInput]:
+) -> tuple[list[ModelInput], bool]:
models = []
# check for conflicts in model ids
@@ -74,7 +74,50 @@ def get_model_registry(
metadata=entry.metadata,
)
)
- return models
+ return models, ids_conflict
+
+
+def get_shield_registry(
+ available_safety_models: dict[str, list[ProviderModelEntry]],
+ ids_conflict_in_models: bool,
+) -> list[ShieldInput]:
+ shields = []
+
+ # check for conflicts in shield ids
+ all_ids = set()
+ ids_conflict = False
+
+ for _, entries in available_safety_models.items():
+ for entry in entries:
+ ids = [entry.provider_model_id] + entry.aliases
+ for model_id in ids:
+ if model_id in all_ids:
+ ids_conflict = True
+ rich.print(
+ f"[yellow]Shield id {model_id} conflicts; all shield ids will be prefixed with provider id[/yellow]"
+ )
+ break
+ all_ids.update(ids)
+ if ids_conflict:
+ break
+ if ids_conflict:
+ break
+
+ for provider_id, entries in available_safety_models.items():
+ for entry in entries:
+ ids = [entry.provider_model_id] + entry.aliases
+ for model_id in ids:
+ identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
+ shields.append(
+ ShieldInput(
+ shield_id=identifier,
+ provider_shield_id=f"{provider_id}/{entry.provider_model_id}"
+ if ids_conflict_in_models
+ else entry.provider_model_id,
+ )
+ )
+
+ return shields
class DefaultModel(BaseModel):
diff --git a/llama_stack/templates/watsonx/watsonx.py b/llama_stack/templates/watsonx/watsonx.py
index 7fa3a55e5..ea185f05d 100644
--- a/llama_stack/templates/watsonx/watsonx.py
+++ b/llama_stack/templates/watsonx/watsonx.py
@@ -69,7 +69,7 @@ def get_distribution_template() -> DistributionTemplate:
},
)
- default_models = get_model_registry(available_models)
+ default_models, _ = get_model_registry(available_models)
return DistributionTemplate(
name="watsonx",
distro_type="remote_hosted",
diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py
index 66c9ab829..05549cf18 100644
--- a/tests/integration/agents/test_agents.py
+++ b/tests/integration/agents/test_agents.py
@@ -77,6 +77,24 @@ def agent_config(llama_stack_client, text_model_id):
return agent_config
+@pytest.fixture(scope="session")
+def agent_config_without_safety(text_model_id):
+ agent_config = dict(
+ model=text_model_id,
+ instructions="You are a helpful assistant",
+ sampling_params={
+ "strategy": {
+ "type": "top_p",
+ "temperature": 0.0001,
+ "top_p": 0.9,
+ },
+ },
+ tools=[],
+ enable_session_persistence=False,
+ )
+ return agent_config
+
+
def test_agent_simple(llama_stack_client, agent_config):
agent = Agent(llama_stack_client, **agent_config)
session_id = agent.create_session(f"test-session-{uuid4()}")
@@ -491,7 +509,7 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
assert expected_kw in response.output_message.content.lower()
-def test_rag_agent_with_attachments(llama_stack_client, agent_config):
+def test_rag_agent_with_attachments(llama_stack_client, agent_config_without_safety):
urls = ["llama3.rst", "lora_finetune.rst"]
documents = [
# passign as url
@@ -514,14 +532,8 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
metadata={},
),
]
- rag_agent = Agent(llama_stack_client, **agent_config)
+ rag_agent = Agent(llama_stack_client, **agent_config_without_safety)
session_id = rag_agent.create_session(f"test-session-{uuid4()}")
- user_prompts = [
- (
- "Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
- "grouped",
- ),
- ]
user_prompts = [
(
"I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
@@ -549,82 +561,6 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
assert "lora" in response.output_message.content.lower()
-@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
-def test_rag_and_code_agent(llama_stack_client, agent_config):
- if "llama-4" in agent_config["model"].lower():
- pytest.xfail("Not working for llama4")
-
- documents = []
- documents.append(
- Document(
- document_id="nba_wiki",
- content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).",
- metadata={},
- )
- )
- documents.append(
- Document(
- document_id="perplexity_wiki",
- content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning:
-
- Srinivas, the CEO, worked at OpenAI as an AI researcher.
- Konwinski was among the founding team at Databricks.
- Yarats, the CTO, was an AI research scientist at Meta.
- Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""",
- metadata={},
- )
- )
- vector_db_id = f"test-vector-db-{uuid4()}"
- llama_stack_client.vector_dbs.register(
- vector_db_id=vector_db_id,
- embedding_model="all-MiniLM-L6-v2",
- embedding_dimension=384,
- )
- llama_stack_client.tool_runtime.rag_tool.insert(
- documents=documents,
- vector_db_id=vector_db_id,
- chunk_size_in_tokens=128,
- )
- agent_config = {
- **agent_config,
- "tools": [
- dict(
- name="builtin::rag/knowledge_search",
- args={"vector_db_ids": [vector_db_id]},
- ),
- "builtin::code_interpreter",
- ],
- }
- agent = Agent(llama_stack_client, **agent_config)
- user_prompts = [
- (
- "when was Perplexity the company founded?",
- [],
- "knowledge_search",
- "2022",
- ),
- (
- "when was the nba created?",
- [],
- "knowledge_search",
- "1949",
- ),
- ]
-
- for prompt, docs, tool_name, expected_kw in user_prompts:
- session_id = agent.create_session(f"test-session-{uuid4()}")
- response = agent.create_turn(
- messages=[{"role": "user", "content": prompt}],
- session_id=session_id,
- documents=docs,
- stream=False,
- )
- tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
- assert tool_execution_step.tool_calls[0].tool_name == tool_name, f"Failed on {prompt}"
- if expected_kw:
- assert expected_kw in response.output_message.content.lower()
-
-
@pytest.mark.parametrize(
"client_tools",
[(get_boiling_point, False), (get_boiling_point_with_metadata, True)],
diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py
new file mode 100644
index 000000000..2f212e374
--- /dev/null
+++ b/tests/unit/providers/vector_io/remote/test_milvus.py
@@ -0,0 +1,191 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+from unittest.mock import MagicMock, patch
+
+import numpy as np
+import pytest
+import pytest_asyncio
+
+from llama_stack.apis.vector_io import QueryChunksResponse
+
+# Mock the entire pymilvus module
+pymilvus_mock = MagicMock()
+pymilvus_mock.DataType = MagicMock()
+pymilvus_mock.MilvusClient = MagicMock
+
+# Apply the mock before importing MilvusIndex
+with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
+ from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
+
+# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
+# tests which are specific to this class. More general (API-level) tests should be placed in
+# tests/integration/vector_io/
+#
+# How to run this test:
+#
+# pytest tests/unit/providers/vector_io/test_milvus.py \
+# -v -s --tb=short --disable-warnings --asyncio-mode=auto
+
+MILVUS_PROVIDER = "milvus"
+
+
+@pytest_asyncio.fixture
+async def mock_milvus_client() -> MagicMock:
+ """Create a mock Milvus client with common method behaviors."""
+ client = MagicMock()
+
+ # Mock collection operations
+ client.has_collection.return_value = False # Initially no collection
+ client.create_collection.return_value = None
+ client.drop_collection.return_value = None
+
+ # Mock insert operation
+ client.insert.return_value = {"insert_count": 10}
+
+ # Mock search operation - return mock results (data should be dict, not JSON string)
+ client.search.return_value = [
+ [
+ {
+ "id": 0,
+ "distance": 0.1,
+ "entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
+ },
+ {
+ "id": 1,
+ "distance": 0.2,
+ "entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
+ },
+ ]
+ ]
+
+ # Mock query operation for keyword search (data should be dict, not JSON string)
+ client.query.return_value = [
+ {
+ "chunk_id": "chunk1",
+ "chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
+ "score": 0.9,
+ },
+ {
+ "chunk_id": "chunk2",
+ "chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
+ "score": 0.8,
+ },
+ {
+ "chunk_id": "chunk3",
+ "chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
+ "score": 0.7,
+ },
+ ]
+
+ return client
+
+
+@pytest_asyncio.fixture
+async def milvus_index(mock_milvus_client):
+ """Create a MilvusIndex with mocked client."""
+ index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
+ yield index
+ # No real cleanup needed since we're using mocks
+
+
+@pytest.mark.asyncio
+async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
+ # Setup: collection doesn't exist initially, then exists after creation
+ mock_milvus_client.has_collection.side_effect = [False, True]
+
+ await milvus_index.add_chunks(sample_chunks, sample_embeddings)
+
+ # Verify collection was created and data was inserted
+ mock_milvus_client.create_collection.assert_called_once()
+ mock_milvus_client.insert.assert_called_once()
+
+ # Verify the insert call had the right number of chunks
+ insert_call = mock_milvus_client.insert.call_args
+ assert len(insert_call[1]["data"]) == len(sample_chunks)
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_vector(
+ milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
+):
+ # Setup: Add chunks first
+ mock_milvus_client.has_collection.return_value = True
+ await milvus_index.add_chunks(sample_chunks, sample_embeddings)
+
+ # Test vector search
+ query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
+ response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
+
+ assert isinstance(response, QueryChunksResponse)
+ assert len(response.chunks) == 2
+ mock_milvus_client.search.assert_called_once()
+
+
+@pytest.mark.asyncio
+async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
+ mock_milvus_client.has_collection.return_value = True
+ await milvus_index.add_chunks(sample_chunks, sample_embeddings)
+
+ # Test keyword search
+ query_string = "Sentence 5"
+ response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
+
+ assert isinstance(response, QueryChunksResponse)
+ assert len(response.chunks) == 2
+
+
+@pytest.mark.asyncio
+async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
+ """Test that when BM25 search fails, the system falls back to simple text search."""
+ mock_milvus_client.has_collection.return_value = True
+ await milvus_index.add_chunks(sample_chunks, sample_embeddings)
+
+ # Force BM25 search to fail
+ mock_milvus_client.search.side_effect = Exception("BM25 search not available")
+
+ # Mock simple text search results
+ mock_milvus_client.query.return_value = [
+ {
+ "chunk_id": "chunk1",
+ "chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
+ },
+ {
+ "chunk_id": "chunk2",
+ "chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
+ },
+ ]
+
+ # Test keyword search that should fall back to simple text search
+ query_string = "Python"
+ response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
+
+ # Verify response structure
+ assert isinstance(response, QueryChunksResponse)
+ assert len(response.chunks) > 0, "Fallback search should return results"
+
+ # Verify that simple text search was used (query method called instead of search)
+ mock_milvus_client.query.assert_called_once()
+ mock_milvus_client.search.assert_called_once() # Called once but failed
+
+ # Verify the query uses parameterized filter with filter_params
+ query_call_args = mock_milvus_client.query.call_args
+ assert "filter" in query_call_args[1], "Query should include filter for text search"
+ assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
+ assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"
+
+ # Verify all returned chunks have score 1.0 (simple binary scoring)
+ assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
+
+
+@pytest.mark.asyncio
+async def test_delete_collection(milvus_index, mock_milvus_client):
+ # Test collection deletion
+ mock_milvus_client.has_collection.return_value = True
+
+ await milvus_index.delete()
+
+ mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)