feat: Enable setting a default embedding model in the stack (#3803)

# What does this PR do?

Enables automatic embedding model detection for vector stores, using a
`default_configured` boolean that can be defined in the `run.yaml`.


## Test Plan
- Unit tests
- Integration tests
- Simple example below:

Spin up the stack:
```bash
uv run llama stack build --distro starter --image-type venv --run
```
Then test with OpenAI's client:
```python
from openai import OpenAI
client = OpenAI(base_url="http://localhost:8321/v1/", api_key="none")
vs = client.vector_stores.create()
```
Previously you needed:

```python
vs = client.vector_stores.create(
    extra_body={
        "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
        "embedding_dimension": 384,
    }
)
```

The `extra_body` is now unnecessary.
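
For reference, the `run.yaml` entry that enables this (mirroring the docs example added in this PR):

```yaml
models:
  - model_id: nomic-ai/nomic-embed-text-v1.5
    provider_id: inline::sentence-transformers
    metadata:
      embedding_dimension: 768
      default_configured: true
```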

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>


@ -10,358 +10,111 @@ import TabItem from '@theme/TabItem';
# Retrieval Augmented Generation (RAG)
RAG enables your applications to reference and recall information from previous interactions or external documents.
RAG enables your applications to reference and recall information from external documents. Llama Stack makes Agentic RAG available through OpenAI's Responses API.
## Quick Start
### 1. Start the Server
In one terminal, start the Llama Stack server:
```bash
uv run llama stack build --distro starter --image-type venv --run
```
### 2. Connect with OpenAI Client
In another terminal, use the standard OpenAI client with the Responses API:
```python
import io, requests
from openai import OpenAI
url = "https://www.paulgraham.com/greatwork.html"
client = OpenAI(base_url="http://localhost:8321/v1/", api_key="none")
# Create vector store - auto-detects default embedding model
vs = client.vector_stores.create()
response = requests.get(url)
pseudo_file = io.BytesIO(str(response.content).encode('utf-8'))
file_id = client.files.create(file=(url, pseudo_file, "text/html"), purpose="assistants").id
client.vector_stores.files.create(vector_store_id=vs.id, file_id=file_id)
resp = client.responses.create(
    model="gpt-4o",
    input="How do you do great work? Use the existing knowledge_search tool.",
    tools=[{"type": "file_search", "vector_store_ids": [vs.id]}],
    include=["file_search_call.results"],
)
print(resp.output[-1].content[-1].text)
```
Which should give output like:
```
Doing great work is about more than just hard work and ambition; it involves combining several elements:
1. **Pursue What Excites You**: Engage in projects that are both ambitious and exciting to you. It's important to work on something you have a natural aptitude for and a deep interest in.
2. **Explore and Discover**: Great work often feels like a blend of discovery and creation. Focus on seeing possibilities and let ideas take their natural shape, rather than just executing a plan.
3. **Be Bold Yet Flexible**: Take bold steps in your work without over-planning. An adaptable approach that evolves with new ideas can often lead to breakthroughs.
4. **Work on Your Own Projects**: Develop a habit of working on projects of your own choosing, as these often lead to great achievements. These should be projects you find exciting and that challenge you intellectually.
5. **Be Earnest and Authentic**: Approach your work with earnestness and authenticity. Trying to impress others with affectation can be counterproductive, as genuine effort and intellectual honesty lead to better work outcomes.
6. **Build a Supportive Environment**: Work alongside great colleagues who inspire you and enhance your work. Surrounding yourself with motivating individuals creates a fertile environment for great work.
7. **Maintain High Morale**: High morale significantly impacts your ability to do great work. Stay optimistic and protect your mental well-being to maintain progress and momentum.
8. **Balance**: While hard work is essential, overworking can lead to diminishing returns. Balance periods of intensive work with rest to sustain productivity over time.
This approach shows that great work is less about following a strict formula and more about aligning your interests, ambition, and environment to foster creativity and innovation.
```
## Architecture Overview
Llama Stack organizes the APIs that enable RAG into three layers:
Llama Stack provides OpenAI-compatible RAG capabilities through:
1. **Lower-Level APIs**: Deal with raw storage and retrieval. These include Vector IO, KeyValue IO (coming soon) and Relational IO (also coming soon)
2. **RAG Tool**: A first-class tool as part of the [Tools API](./tools) that allows you to ingest documents (from URLs, files, etc) with various chunking strategies and query them smartly
3. **Agents API**: The top-level [Agents API](./agent) that allows you to create agents that can use the tools to answer questions, perform tasks, and more
- **Vector Stores API**: OpenAI-compatible vector storage with automatic embedding model detection
- **Files API**: Document upload and processing using OpenAI's file format
- **Responses API**: Enhanced chat completions with agentic tool calling via file search
![RAG System Architecture](/img/rag.png)
## Configuring Default Embedding Models
The RAG system uses lower-level storage for different types of data:
- **Vector IO**: For semantic search and retrieval
- **Key-Value and Relational IO**: For structured data storage
To enable automatic vector store creation without specifying embedding models, configure a default embedding model in your run.yaml like so:
:::info[Future Storage Types]
We may add more storage types like Graph IO in the future.
:::
## Setting up Vector Databases
For this guide, we will use [Ollama](https://ollama.com/) as the inference provider. Ollama is an LLM runtime that allows you to run Llama models locally.
Here's how to set up a vector database for RAG:
```python
# Create HTTP client
import os
from llama_stack_client import LlamaStackClient
client = LlamaStackClient(base_url=f"http://localhost:{os.environ['LLAMA_STACK_PORT']}")
# Register a vector database
vector_db_id = "my_documents"
response = client.vector_dbs.register(
    vector_db_id=vector_db_id,
    embedding_model="nomic-embed-text-v1.5",
    embedding_dimension=768,
    provider_id="faiss",
)
```
```yaml
models:
  - model_id: nomic-ai/nomic-embed-text-v1.5
    provider_id: inline::sentence-transformers
    metadata:
      embedding_dimension: 768
      default_configured: true
```
## Document Ingestion
With this configuration:
- `client.vector_stores.create()` works without requiring embedding model parameters (see the example after this list)
- The system automatically uses the default model and its embedding dimension for any newly created vector store
- Only one model can be marked as `default_configured: true`
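With the default in place, creating a store needs no embedding parameters at all; a minimal sketch using the OpenAI client from the Quick Start:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/", api_key="none")

# No extra_body needed: the stack resolves the default embedding model and
# its dimension from the run.yaml metadata shown above.
vs = client.vector_stores.create()
print(vs.id)
```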
You can ingest documents into the vector database using two methods: directly inserting pre-chunked documents or using the RAG Tool.
## Vector Store Operations
### Direct Document Insertion
### Creating Vector Stores
<Tabs>
<TabItem value="basic" label="Basic Insertion">
You can create vector stores with automatic or explicit embedding model selection:
```python
# You can insert a pre-chunked document directly into the vector db
chunks = [
    {
        "content": "Your document text here",
        "mime_type": "text/plain",
        "metadata": {
            "document_id": "doc1",
            "author": "Jane Doe",
        },
    },
]
client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks)
```
# Automatic - uses default configured embedding model
vs = client.vector_stores.create()
</TabItem>
<TabItem value="embeddings" label="With Precomputed Embeddings">
If you decide to precompute embeddings for your documents, you can insert them directly into the vector database by including the embedding vectors in the chunk data. This is useful if you have a separate embedding service or if you want to customize the ingestion process.
```python
chunks_with_embeddings = [
    {
        "content": "First chunk of text",
        "mime_type": "text/plain",
        "embedding": [0.1, 0.2, 0.3, ...],  # Your precomputed embedding vector
        "metadata": {"document_id": "doc1", "section": "introduction"},
    },
    {
        "content": "Second chunk of text",
        "mime_type": "text/plain",
        "embedding": [0.2, 0.3, 0.4, ...],  # Your precomputed embedding vector
        "metadata": {"document_id": "doc1", "section": "methodology"},
    },
]
client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks_with_embeddings)
```
:::warning[Embedding Dimensions]
When providing precomputed embeddings, ensure the embedding dimension matches the `embedding_dimension` specified when registering the vector database.
:::
</TabItem>
</Tabs>
### Document Retrieval
You can query the vector database to retrieve documents based on their embeddings.
```python
# You can then query for these chunks
chunks_response = client.vector_io.query(
    vector_db_id=vector_db_id,
    query="What do you know about...",
)
# Explicit - specify embedding model when you need a specific one
vs = client.vector_stores.create(
    extra_body={
        "embedding_model": "nomic-ai/nomic-embed-text-v1.5",
        "embedding_dimension": 768,
    }
)
```
## Using the RAG Tool
:::danger[Deprecation Notice]
The RAG Tool is being deprecated in favor of directly using the OpenAI-compatible Search API. We recommend migrating to the OpenAI APIs for better compatibility and future support.
:::
A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc. and automatically chunks them into smaller pieces. More examples for how to format a RAGDocument can be found in the [appendix](#more-ragdocument-examples).
### OpenAI API Integration & Migration
The RAG tool has been updated to use OpenAI-compatible APIs. This provides several benefits:
- **Files API Integration**: Documents are now uploaded using OpenAI's file upload endpoints
- **Vector Stores API**: Vector storage operations use OpenAI's vector store format with configurable chunking strategies
- **Error Resilience**: When processing multiple documents, individual failures are logged but don't crash the operation. Failed documents are skipped while successful ones continue processing.
### Migration Path
We recommend migrating to the OpenAI-compatible Search API for:
1. **Better OpenAI Ecosystem Integration**: Direct compatibility with OpenAI tools and workflows including the Responses API
2. **Future-Proof**: Continued support and feature development
3. **Full OpenAI Compatibility**: Vector Stores, Files, and Search APIs are fully compatible with OpenAI's Responses API
The OpenAI APIs are used under the hood, so you can continue to use your existing RAG Tool code with minimal changes. However, we recommend updating your code to use the new OpenAI-compatible APIs for better long-term support. If any documents fail to process, they will be logged in the response but will not cause the entire operation to fail.
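As a rough migration sketch (one possible mapping, not the only one), RAG Tool ingestion corresponds to a Files upload plus a vector store attachment; the file name and contents here are illustrative:

```python
import io

# Before: client.tool_runtime.rag_tool.insert(documents=..., vector_db_id=...)
# After: upload via the Files API, then attach the file to a vector store.
vs = client.vector_stores.create()
uploaded = client.files.create(
    file=("doc.txt", io.BytesIO(b"Your document text here"), "text/plain"),
    purpose="assistants",
)
client.vector_stores.files.create(vector_store_id=vs.id, file_id=uploaded.id)
```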
### RAG Tool Example
```python
from llama_stack_client import RAGDocument

urls = ["memory_optimizations.rst", "chat.rst", "llama3.rst"]
documents = [
    RAGDocument(
        document_id=f"num-{i}",
        content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
        mime_type="text/plain",
        metadata={},
    )
    for i, url in enumerate(urls)
]

client.tool_runtime.rag_tool.insert(
    documents=documents,
    vector_db_id=vector_db_id,
    chunk_size_in_tokens=512,
)

# Query documents
results = client.tool_runtime.rag_tool.query(
    vector_db_ids=[vector_db_id],
    content="What do you know about...",
)
```
### Custom Context Configuration
You can configure how the RAG tool adds metadata to the context if you find it useful for your application:
```python
# Query documents with custom template
results = client.tool_runtime.rag_tool.query(
    vector_db_ids=[vector_db_id],
    content="What do you know about...",
    query_config={
        "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
    },
)
```
## Building RAG-Enhanced Agents
One of the most powerful patterns is combining agents with RAG capabilities. Here's a complete example:
### Agent with Knowledge Search
```python
from llama_stack_client import Agent

# Create agent with memory
agent = Agent(
    client,
    model="meta-llama/Llama-3.3-70B-Instruct",
    instructions="You are a helpful assistant",
    tools=[
        {
            "name": "builtin::rag/knowledge_search",
            "args": {
                "vector_db_ids": [vector_db_id],
                # Defaults
                "query_config": {
                    "chunk_size_in_tokens": 512,
                    "chunk_overlap_in_tokens": 0,
                    "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
                },
            },
        }
    ],
)
session_id = agent.create_session("rag_session")

# Ask questions about documents in the vector db, and the agent will query the db to answer the question.
response = agent.create_turn(
    messages=[{"role": "user", "content": "How to optimize memory in PyTorch?"}],
    session_id=session_id,
)
```
:::tip[Agent Instructions]
The `instructions` field in the `AgentConfig` can be used to guide the agent's behavior. It is important to experiment with different instructions to see what works best for your use case.
:::
### Document-Aware Conversations
You can also pass documents along with the user's message and ask questions about them:
```python
# Initial document ingestion
response = agent.create_turn(
    messages=[
        {"role": "user", "content": "I am providing some documents for reference."}
    ],
    documents=[
        {
            "content": "https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/memory_optimizations.rst",
            "mime_type": "text/plain",
        }
    ],
    session_id=session_id,
)

# Query with RAG
response = agent.create_turn(
    messages=[{"role": "user", "content": "What are the key topics in the documents?"}],
    session_id=session_id,
)
```
### Viewing Agent Responses
You can print the response with the following:
```python
from llama_stack_client import AgentEventLogger

for log in AgentEventLogger().log(response):
    log.print()
```
## Vector Database Management
### Unregistering Vector DBs
If you need to clean up and unregister vector databases, you can do so as follows:
<Tabs>
<TabItem value="single" label="Single Database">
```python
# Unregister a specified vector database
vector_db_id = "my_vector_db_id"
print(f"Unregistering vector database: {vector_db_id}")
client.vector_dbs.unregister(vector_db_id=vector_db_id)
```
</TabItem>
<TabItem value="all" label="All Databases">
```python
# Unregister all vector databases
for vector_db_id in client.vector_dbs.list():
    print(f"Unregistering vector database: {vector_db_id.identifier}")
    client.vector_dbs.unregister(vector_db_id=vector_db_id.identifier)
```
</TabItem>
</Tabs>
## Best Practices
### 🎯 **Document Chunking**
- Use appropriate chunk sizes (512 tokens is often a good starting point)
- Consider overlap between chunks for better context preservation (see the sketch after this list)
- Experiment with different chunking strategies for your content type
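As a starting point, the chunking knobs can be set through `query_config` (the same keys used in the agent example above); the overlap value here is illustrative, not a recommendation:

```python
results = client.tool_runtime.rag_tool.query(
    vector_db_ids=[vector_db_id],
    content="How do I reduce GPU memory usage?",
    query_config={
        "chunk_size_in_tokens": 512,  # a common starting point
        "chunk_overlap_in_tokens": 64,  # illustrative; tune for your content
        "chunk_template": "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n",
    },
)
```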
### 🔍 **Embedding Strategy**
- Choose embedding models that match your domain
- Consider the trade-off between embedding dimension and performance
- Test different embedding models for your specific use case
### 📊 **Query Optimization**
- Use specific, well-formed queries for better retrieval
- Experiment with different search strategies
- Consider hybrid approaches (keyword + semantic search)
### 🛡️ **Error Handling**
- Implement proper error handling for failed document processing (see the sketch after this list)
- Monitor ingestion success rates
- Have fallback strategies for retrieval failures
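A minimal sketch of the first bullet, assuming files are attached individually via the Vector Stores API shown earlier (`file_ids` is a hypothetical list of already-uploaded file IDs):

```python
# Attach files one at a time so a single bad document doesn't abort the batch.
failed = []
for file_id in file_ids:
    try:
        client.vector_stores.files.create(vector_store_id=vs.id, file_id=file_id)
    except Exception as exc:
        failed.append((file_id, str(exc)))  # record for review, keep going

print(f"Ingested {len(file_ids) - len(failed)} files; {len(failed)} failed")
```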
## Appendix
### More RAGDocument Examples
Here are various ways to create RAGDocument objects for different content types:
```python
import base64

import requests

from llama_stack_client import RAGDocument

# File URI
RAGDocument(document_id="num-0", content={"uri": "file://path/to/file"})

# Plain text
RAGDocument(document_id="num-1", content="plain text")

# Explicit text input
RAGDocument(
    document_id="num-2",
    content={
        "type": "text",
        "text": "plain text input",
    },  # for inputs that should be treated as text explicitly
)

# Image from URL
RAGDocument(
    document_id="num-3",
    content={
        "type": "image",
        "image": {"url": {"uri": "https://mywebsite.com/image.jpg"}},
    },
)

# Base64 encoded image
B64_ENCODED_IMAGE = base64.b64encode(
    requests.get(
        "https://raw.githubusercontent.com/meta-llama/llama-stack/refs/heads/main/docs/_static/llama-stack.png"
    ).content
)
RAGDocument(
    document_id="num-4",
    content={"type": "image", "image": {"data": B64_ENCODED_IMAGE}},
)
```
For more strongly typed interaction use the typed dicts found [here](https://github.com/meta-llama/llama-stack-client-python/blob/38cd91c9e396f2be0bec1ee96a19771582ba6f17/src/llama_stack_client/types/shared_params/document.py).


@ -496,12 +496,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
        return await response.parse()

    def _convert_body(self, func: Any, body: dict | None = None, exclude_params: set[str] | None = None) -> dict:
        if not body:
            return {}
        body = body or {}
        exclude_params = exclude_params or set()
        sig = inspect.signature(func)
        params_list = [p for p in sig.parameters.values() if p.name != "self"]

        # Flatten if there's a single unwrapped body parameter (BaseModel or Annotated[BaseModel, Body(embed=False)])
        if len(params_list) == 1:
            param = params_list[0]
@ -530,11 +529,12 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
                    converted_body[param_name] = value
                else:
                    converted_body[param_name] = convert_to_pydantic(param.annotation, value)
            elif unwrapped_body_param and param.name == unwrapped_body_param.name:
                # This is the unwrapped body param - construct it from remaining body keys
                base_type = get_args(param.annotation)[0]
                # Extract only the keys that aren't already used by other params
                remaining_keys = {k: v for k, v in body.items() if k not in converted_body}
                converted_body[param.name] = base_type(**remaining_keys)

        # handle unwrapped body parameter after processing all named parameters
        if unwrapped_body_param:
            base_type = get_args(unwrapped_body_param.annotation)[0]
            # extract only keys not already used by other params
            remaining_keys = {k: v for k, v in body.items() if k not in converted_body}
            converted_body[unwrapped_body_param.name] = base_type(**remaining_keys)

        return converted_body


@ -120,13 +120,7 @@ class VectorIORouter(VectorIO):
        embedding_dimension = extra.get("embedding_dimension")
        provider_id = extra.get("provider_id")

        logger.debug(f"VectorIORouter.openai_create_vector_store: name={params.name}, provider_id={provider_id}")

        # Require explicit embedding model specification
        if embedding_model is None:
            raise ValueError("embedding_model is required in extra_body when creating a vector store")

        if embedding_dimension is None:
        if embedding_model is not None and embedding_dimension is None:
            embedding_dimension = await self._get_embedding_model_dimension(embedding_model)

        # Auto-select provider if not specified
@ -158,8 +152,10 @@ class VectorIORouter(VectorIO):
            params.model_extra = {}
        params.model_extra["provider_vector_db_id"] = registered_vector_db.provider_resource_id
        params.model_extra["provider_id"] = registered_vector_db.provider_id
        params.model_extra["embedding_model"] = embedding_model
        params.model_extra["embedding_dimension"] = embedding_dimension
        if embedding_model is not None:
            params.model_extra["embedding_model"] = embedding_model
        if embedding_dimension is not None:
            params.model_extra["embedding_dimension"] = embedding_dimension

        return await provider.openai_create_vector_store(params)


@ -98,6 +98,30 @@ REGISTRY_REFRESH_TASK = None
TEST_RECORDING_CONTEXT = None
async def validate_default_embedding_model(impls: dict[Api, Any]):
    """Validate that at most one embedding model is marked as default."""
    if Api.models not in impls:
        return

    models_impl = impls[Api.models]
    response = await models_impl.list_models()
    models_list = response.data if hasattr(response, "data") else response

    default_embedding_models = []
    for model in models_list:
        if model.model_type == "embedding" and model.metadata.get("default_configured") is True:
            default_embedding_models.append(model.identifier)

    if len(default_embedding_models) > 1:
        raise ValueError(
            f"Multiple embedding models marked as default_configured=True: {default_embedding_models}. "
            "Only one embedding model can be marked as default."
        )

    if default_embedding_models:
        logger.info(f"Default embedding model configured: {default_embedding_models[0]}")
async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
    for rsrc, api, register_method, list_method in RESOURCES:
        objects = getattr(run_config, rsrc)
@ -128,6 +152,8 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
f"{rsrc.capitalize()}: {obj.identifier} served by {obj.provider_id}",
)
await validate_default_embedding_model(impls)
class EnvVarError(Exception):
    def __init__(self, var_name: str, path: str = ""):


@ -59,6 +59,7 @@ class SentenceTransformersInferenceImpl(
                provider_id=self.__provider_id__,
                metadata={
                    "embedding_dimension": 768,
                    "default_configured": True,
                },
                model_type=ModelType.embedding,
            ),


@ -16,6 +16,11 @@ async def get_provider_impl(config: ChromaVectorIOConfig, deps: dict[Api, Any]):
        ChromaVectorIOAdapter,
    )

    impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
    impl = ChromaVectorIOAdapter(
        config,
        deps[Api.inference],
        deps[Api.models],
        deps.get(Api.files),
    )
    await impl.initialize()
    return impl


@ -16,6 +16,11 @@ async def get_provider_impl(config: FaissVectorIOConfig, deps: dict[Api, Any]):
    assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"

    impl = FaissVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
    impl = FaissVectorIOAdapter(
        config,
        deps[Api.inference],
        deps[Api.models],
        deps.get(Api.files),
    )
    await impl.initialize()
    return impl


@ -17,6 +17,7 @@ from numpy.typing import NDArray
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
@ -199,10 +200,17 @@ class FaissIndex(EmbeddingIndex):
class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
    def __init__(self, config: FaissVectorIOConfig, inference_api: Inference, files_api: Files | None) -> None:
    def __init__(
        self,
        config: FaissVectorIOConfig,
        inference_api: Inference,
        models_api: Models,
        files_api: Files | None,
    ) -> None:
        super().__init__(files_api=files_api, kvstore=None)
        self.config = config
        self.inference_api = inference_api
        self.models_api = models_api
        self.cache: dict[str, VectorDBWithIndex] = {}

    async def initialize(self) -> None:


@ -14,6 +14,11 @@ from .config import MilvusVectorIOConfig
async def get_provider_impl(config: MilvusVectorIOConfig, deps: dict[Api, Any]):
    from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter

    impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
    impl = MilvusVectorIOAdapter(
        config,
        deps[Api.inference],
        deps[Api.models],
        deps.get(Api.files),
    )
    await impl.initialize()
    return impl


@ -15,7 +15,11 @@ async def get_provider_impl(config: QdrantVectorIOConfig, deps: dict[Api, Any]):
from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter
assert isinstance(config, QdrantVectorIOConfig), f"Unexpected config type: {type(config)}"
files_api = deps.get(Api.files)
impl = QdrantVectorIOAdapter(config, deps[Api.inference], files_api)
impl = QdrantVectorIOAdapter(
config,
deps[Api.inference],
deps[Api.models],
deps.get(Api.files),
)
await impl.initialize()
return impl


@ -15,6 +15,11 @@ async def get_provider_impl(config: SQLiteVectorIOConfig, deps: dict[Api, Any]):
    from .sqlite_vec import SQLiteVecVectorIOAdapter

    assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}"
    impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
    impl = SQLiteVecVectorIOAdapter(
        config,
        deps[Api.inference],
        deps[Api.models],
        deps.get(Api.files),
    )
    await impl.initialize()
    return impl


@ -17,6 +17,7 @@ from numpy.typing import NDArray
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
@ -409,11 +410,19 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
    and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex).
    """

    def __init__(self, config, inference_api: Inference, files_api: Files | None) -> None:
    def __init__(
        self,
        config,
        inference_api: Inference,
        models_api: Models,
        files_api: Files | None,
    ) -> None:
        super().__init__(files_api=files_api, kvstore=None)
        self.config = config
        self.inference_api = inference_api
        self.models_api = models_api
        self.cache: dict[str, VectorDBWithIndex] = {}
        self.vector_db_store = None

    async def initialize(self) -> None:
        self.kvstore = await kvstore_impl(self.config.kvstore)


@ -26,7 +26,7 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
deprecation_warning="Please use the `inline::faiss` provider instead.",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="Meta's reference implementation of a vector database.",
),
InlineProviderSpec(
@ -36,7 +36,7 @@ def available_providers() -> list[ProviderSpec]:
module="llama_stack.providers.inline.vector_io.faiss",
config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
[Faiss](https://github.com/facebookresearch/faiss) is an inline vector database provider for Llama Stack. It
allows you to store and query vectors directly in memory.
@ -89,7 +89,7 @@ more details about Faiss in general.
module="llama_stack.providers.inline.vector_io.sqlite_vec",
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
[SQLite-Vec](https://github.com/asg017/sqlite-vec) is an inline vector database provider for Llama Stack. It
allows you to store and query vectors directly within an SQLite database.
@ -297,7 +297,7 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig",
deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
Please refer to the sqlite-vec provider documentation.
""",
@ -310,7 +310,7 @@ Please refer to the sqlite-vec provider documentation.
module="llama_stack.providers.remote.vector_io.chroma",
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
[Chroma](https://www.trychroma.com/) is an inline and remote vector
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
@ -352,7 +352,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
module="llama_stack.providers.inline.vector_io.chroma",
config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
[Chroma](https://www.trychroma.com/) is an inline and remote vector
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
@ -396,7 +396,7 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
module="llama_stack.providers.remote.vector_io.pgvector",
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
allows you to store and query vectors directly in memory.
@ -508,7 +508,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
It allows you to store and query vectors directly within a Weaviate database.
@ -548,7 +548,7 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
module="llama_stack.providers.inline.vector_io.qdrant",
config_class="llama_stack.providers.inline.vector_io.qdrant.QdrantVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description=r"""
[Qdrant](https://qdrant.tech/documentation/) is an inline and remote vector database provider for Llama Stack. It
allows you to store and query vectors directly in memory.
@ -601,7 +601,7 @@ See the [Qdrant documentation](https://qdrant.tech/documentation/) for more deta
module="llama_stack.providers.remote.vector_io.qdrant",
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
Please refer to the inline provider documentation.
""",
@ -614,7 +614,7 @@ Please refer to the inline provider documentation.
module="llama_stack.providers.remote.vector_io.milvus",
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
allows you to store and query vectors directly within a Milvus database.
@ -820,7 +820,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
module="llama_stack.providers.inline.vector_io.milvus",
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
optional_api_dependencies=[Api.files, Api.models],
description="""
Please refer to the remote provider documentation.
""",


@ -12,6 +12,11 @@ from .config import ChromaVectorIOConfig
async def get_adapter_impl(config: ChromaVectorIOConfig, deps: dict[Api, ProviderSpec]):
    from .chroma import ChromaVectorIOAdapter

    impl = ChromaVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files))
    impl = ChromaVectorIOAdapter(
        config,
        deps[Api.inference],
        deps[Api.models],
        deps.get(Api.files),
    )
    await impl.initialize()
    return impl


@ -138,12 +138,14 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
        self,
        config: RemoteChromaVectorIOConfig | InlineChromaVectorIOConfig,
        inference_api: Api.inference,
        models_api: Api.models,
        files_api: Files | None,
    ) -> None:
        super().__init__(files_api=files_api, kvstore=None)
        log.info(f"Initializing ChromaVectorIOAdapter with url: {config}")
        self.config = config
        self.inference_api = inference_api
        self.models_api = models_api
        self.client = None
        self.cache = {}
        self.vector_db_store = None


@ -14,6 +14,11 @@ async def get_adapter_impl(config: MilvusVectorIOConfig, deps: dict[Api, Provide
    assert isinstance(config, MilvusVectorIOConfig), f"Unexpected config type: {type(config)}"

    impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
    impl = MilvusVectorIOAdapter(
        config,
        deps[Api.inference],
        deps[Api.models],
        deps.get(Api.files),
    )
    await impl.initialize()
    return impl


@ -12,8 +12,9 @@ from numpy.typing import NDArray
from pymilvus import AnnSearchRequest, DataType, Function, FunctionType, MilvusClient, RRFRanker, WeightedRanker
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files.files import Files
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
@ -307,6 +308,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
        self,
        config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
        inference_api: Inference,
        models_api: Models,
        files_api: Files | None,
    ) -> None:
        super().__init__(files_api=files_api, kvstore=None)
@ -314,6 +316,7 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
        self.cache = {}
        self.client = None
        self.inference_api = inference_api
        self.models_api = models_api
        self.vector_db_store = None
        self.metadata_collection_name = "openai_vector_stores_metadata"


@ -12,6 +12,6 @@ from .config import PGVectorVectorIOConfig
async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
    from .pgvector import PGVectorVectorIOAdapter

    impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
    impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps[Api.models], deps.get(Api.files, None))
    await impl.initialize()
    return impl


@ -14,8 +14,9 @@ from psycopg2.extras import Json, execute_values
from pydantic import BaseModel, TypeAdapter
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files.files import Files
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
@ -23,7 +24,7 @@ from llama_stack.apis.vector_io import (
VectorIO,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
@ -342,12 +343,14 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoco
    def __init__(
        self,
        config: PGVectorVectorIOConfig,
        inference_api: Api.inference,
        inference_api: Inference,
        models_api: Models,
        files_api: Files | None = None,
    ) -> None:
        super().__init__(files_api=files_api, kvstore=None)
        self.config = config
        self.inference_api = inference_api
        self.models_api = models_api
        self.conn = None
        self.cache = {}
        self.vector_db_store = None


@ -12,7 +12,11 @@ from .config import QdrantVectorIOConfig
async def get_adapter_impl(config: QdrantVectorIOConfig, deps: dict[Api, ProviderSpec]):
    from .qdrant import QdrantVectorIOAdapter

    files_api = deps.get(Api.files)
    impl = QdrantVectorIOAdapter(config, deps[Api.inference], files_api)
    impl = QdrantVectorIOAdapter(
        config,
        deps[Api.inference],
        deps[Api.models],
        deps.get(Api.files),
    )
    await impl.initialize()
    return impl


@ -15,7 +15,8 @@ from qdrant_client.models import PointStruct
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
@ -25,7 +26,7 @@ from llama_stack.apis.vector_io import (
VectorStoreFileObject,
)
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
@ -159,7 +160,8 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
    def __init__(
        self,
        config: RemoteQdrantVectorIOConfig | InlineQdrantVectorIOConfig,
        inference_api: Api.inference,
        inference_api: Inference,
        models_api: Models,
        files_api: Files | None = None,
    ) -> None:
        super().__init__(files_api=files_api, kvstore=None)
@ -167,6 +169,7 @@ class QdrantVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
        self.client: AsyncQdrantClient = None
        self.cache = {}
        self.inference_api = inference_api
        self.models_api = models_api
        self.vector_db_store = None
        self._qdrant_lock = asyncio.Lock()


@ -12,6 +12,11 @@ from .config import WeaviateVectorIOConfig
async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec]):
    from .weaviate import WeaviateVectorIOAdapter

    impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
    impl = WeaviateVectorIOAdapter(
        config,
        deps[Api.inference],
        deps[Api.models],
        deps.get(Api.files),
    )
    await impl.initialize()
    return impl


@ -14,12 +14,14 @@ from weaviate.classes.query import Filter, HybridFusion
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files.files import Files
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.core.request_headers import NeedsRequestProviderData
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
@ -281,12 +283,14 @@ class WeaviateVectorIOAdapter(
    def __init__(
        self,
        config: WeaviateVectorIOConfig,
        inference_api: Api.inference,
        inference_api: Inference,
        models_api: Models,
        files_api: Files | None,
    ) -> None:
        super().__init__(files_api=files_api, kvstore=None)
        self.config = config
        self.inference_api = inference_api
        self.models_api = models_api
        self.client_cache = {}
        self.cache = {}
        self.vector_db_store = None


@ -17,6 +17,7 @@ from pydantic import TypeAdapter
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.files import Files, OpenAIFileObject
from llama_stack.apis.models import Model, Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
@ -77,11 +78,14 @@ class OpenAIVectorStoreMixin(ABC):
    # Implementing classes should call super().__init__() in their __init__ method
    # to properly initialize the mixin attributes.

    def __init__(self, files_api: Files | None = None, kvstore: KVStore | None = None):
    def __init__(
        self, files_api: Files | None = None, kvstore: KVStore | None = None, models_api: Models | None = None
    ):
        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
        self.openai_file_batches: dict[str, dict[str, Any]] = {}
        self.files_api = files_api
        self.kvstore = kvstore
        self.models_api = models_api
        self._last_file_batch_cleanup_time = 0
        self._file_batch_tasks: dict[str, asyncio.Task[None]] = {}
@ -348,20 +352,32 @@ class OpenAIVectorStoreMixin(ABC):
"""Creates a vector store."""
created_at = int(time.time())
# Extract llama-stack-specific parameters from extra_body
extra = params.model_extra or {}
provider_vector_db_id = extra.get("provider_vector_db_id")
embedding_model = extra.get("embedding_model")
embedding_dimension = extra.get("embedding_dimension", 768)
embedding_dimension = extra.get("embedding_dimension")
# use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
provider_id = extra.get("provider_id") or getattr(self, "__provider_id__", None)
# Derive the canonical vector_db_id (allow override, else generate)
vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
if embedding_model is None:
raise ValueError("Embedding model is required")
result = await self._get_default_embedding_model_and_dimension()
if result is None:
raise ValueError(
"embedding_model is required in extra_body when creating a vector store. "
"No default embedding model could be determined automatically."
)
embedding_model, embedding_dimension = result
elif embedding_dimension is None:
# Embedding model was provided but dimension wasn't, look it up
embedding_dimension = await self._get_embedding_dimension_for_model(embedding_model)
if embedding_dimension is None:
raise ValueError(
f"Could not determine embedding dimension for model '{embedding_model}'. "
"Please provide embedding_dimension in extra_body or ensure the model metadata contains embedding_dimension."
)
# Embedding dimension is required (defaulted to 768 if not provided)
if embedding_dimension is None:
raise ValueError("Embedding dimension is required")
@ -428,6 +444,85 @@ class OpenAIVectorStoreMixin(ABC):
        store_info = self.openai_vector_stores[vector_db_id]
        return VectorStoreObject.model_validate(store_info)
    async def _get_embedding_models(self) -> list[Model]:
        """Get list of embedding models from the models API."""
        if not self.models_api:
            return []

        models_response = await self.models_api.list_models()
        models_list = models_response.data if hasattr(models_response, "data") else models_response

        embedding_models = []
        for model in models_list:
            if not isinstance(model, Model):
                logger.warning(f"Non-Model object found in models list: {type(model)} - {model}")
                continue
            if model.model_type == "embedding":
                embedding_models.append(model)
        return embedding_models

    async def _get_embedding_dimension_for_model(self, model_id: str) -> int | None:
        """Get embedding dimension for a specific model by looking it up in the models API.

        Args:
            model_id: The identifier of the embedding model (supports both prefixed and non-prefixed)

        Returns:
            The embedding dimension for the model, or None if not found
        """
        embedding_models = await self._get_embedding_models()
        for model in embedding_models:
            # Check for exact match first
            if model.identifier == model_id:
                embedding_dimension = model.metadata.get("embedding_dimension")
                if embedding_dimension is not None:
                    return int(embedding_dimension)
                else:
                    logger.warning(f"Model {model_id} found but has no embedding_dimension in metadata")
                    return None
            # Check for prefixed/unprefixed variations
            # If model_id is unprefixed, check if it matches the resource_id
            if model.provider_resource_id == model_id:
                embedding_dimension = model.metadata.get("embedding_dimension")
                if embedding_dimension is not None:
                    return int(embedding_dimension)
        return None

    async def _get_default_embedding_model_and_dimension(self) -> tuple[str, int] | None:
        """Get default embedding model from the models API.

        Looks for embedding models marked with default_configured=True in metadata.
        Returns None if no default embedding model is found.
        Raises ValueError if multiple defaults are found.
        """
        embedding_models = await self._get_embedding_models()

        default_models = []
        for model in embedding_models:
            if model.metadata.get("default_configured") is True:
                default_models.append(model.identifier)

        if len(default_models) > 1:
            raise ValueError(
                f"Multiple embedding models marked as default_configured=True: {default_models}. "
                "Only one embedding model can be marked as default."
            )

        if default_models:
            model_id = default_models[0]
            embedding_dimension = await self._get_embedding_dimension_for_model(model_id)
            if embedding_dimension is None:
                raise ValueError(f"Embedding model '{model_id}' has no embedding_dimension in metadata")
            logger.info(f"Using default embedding model: {model_id} with dimension {embedding_dimension}")
            return model_id, embedding_dimension

        logger.debug("No default embedding models found")
        return None

    async def openai_list_vector_stores(
        self,
        limit: int | None = 20,


@ -159,6 +159,12 @@ def test_openai_create_vector_store(
    assert hasattr(vector_store, "created_at")


def test_openai_create_vector_store_default(compat_client_with_empty_stores, client_with_models):
    skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)

    vector_store = compat_client_with_empty_stores.vector_stores.create()
    assert vector_store.id


def test_openai_list_vector_stores(
    compat_client_with_empty_stores, client_with_models, embedding_model_id, embedding_dimension
):


@ -0,0 +1,93 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

"""
Unit tests for Stack validation functions.
"""

from unittest.mock import AsyncMock

import pytest

from llama_stack.apis.models import Model, ModelType
from llama_stack.core.stack import validate_default_embedding_model
from llama_stack.providers.datatypes import Api


class TestStackValidation:
    """Test Stack validation functions."""

    @pytest.mark.parametrize(
        "models,should_raise",
        [
            ([], False),  # No models
            (
                [
                    Model(
                        identifier="emb1",
                        model_type=ModelType.embedding,
                        metadata={"default_configured": True},
                        provider_id="p",
                        provider_resource_id="emb1",
                    )
                ],
                False,
            ),  # Single default
            (
                [
                    Model(
                        identifier="emb1",
                        model_type=ModelType.embedding,
                        metadata={"default_configured": True},
                        provider_id="p",
                        provider_resource_id="emb1",
                    ),
                    Model(
                        identifier="emb2",
                        model_type=ModelType.embedding,
                        metadata={"default_configured": True},
                        provider_id="p",
                        provider_resource_id="emb2",
                    ),
                ],
                True,
            ),  # Multiple defaults
            (
                [
                    Model(
                        identifier="emb1",
                        model_type=ModelType.embedding,
                        metadata={"default_configured": True},
                        provider_id="p",
                        provider_resource_id="emb1",
                    ),
                    Model(
                        identifier="llm1",
                        model_type=ModelType.llm,
                        metadata={"default_configured": True},
                        provider_id="p",
                        provider_resource_id="llm1",
                    ),
                ],
                False,
            ),  # Ignores non-embedding
        ],
    )
    async def test_validate_default_embedding_model(self, models, should_raise):
        """Test validation with various model configurations."""
        mock_models_impl = AsyncMock()
        mock_models_impl.list_models.return_value = models
        impls = {Api.models: mock_models_impl}

        if should_raise:
            with pytest.raises(ValueError, match="Multiple embedding models marked as default_configured=True"):
                await validate_default_embedding_model(impls)
        else:
            await validate_default_embedding_model(impls)

    async def test_validate_default_embedding_model_no_models_api(self):
        """Test validation when models API is not available."""
        await validate_default_embedding_model({})


@ -144,6 +144,7 @@ async def sqlite_vec_adapter(sqlite_vec_db_path, unique_kvstore_config, mock_inf
        config=config,
        inference_api=mock_inference_api,
        files_api=None,
        models_api=None,
    )
    collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
    await adapter.initialize()
@ -182,6 +183,7 @@ async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding
        config=config,
        inference_api=mock_inference_api,
        files_api=None,
        models_api=None,
    )
    await adapter.initialize()
    await adapter.register_vector_db(


@ -11,6 +11,7 @@ import numpy as np
import pytest
from llama_stack.apis.files import Files
from llama_stack.apis.models import Models
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.datatypes import HealthStatus
@ -75,6 +76,12 @@ def mock_files_api():
    return mock_api


@pytest.fixture
def mock_models_api():
    mock_api = MagicMock(spec=Models)
    return mock_api


@pytest.fixture
def faiss_config():
    config = MagicMock(spec=FaissVectorIOConfig)
@ -110,7 +117,7 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
    assert response.chunks[1] == sample_chunks[1]


async def test_health_success():
async def test_health_success(mock_models_api):
    """Test that the health check returns OK status when faiss is working correctly."""
    # Create a fresh instance of FaissVectorIOAdapter for testing
    config = MagicMock()
@ -119,7 +126,9 @@ async def test_health_success():
    with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
        mock_index_flat.return_value = MagicMock()
        adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)
        adapter = FaissVectorIOAdapter(
            config=config, inference_api=inference_api, models_api=mock_models_api, files_api=files_api
        )

        # Calling the health method directly
        response = await adapter.health()
@ -133,7 +142,7 @@ async def test_health_success():
        mock_index_flat.assert_called_once_with(128)  # VECTOR_DIMENSION is 128


async def test_health_failure():
async def test_health_failure(mock_models_api):
    """Test that the health check returns ERROR status when faiss encounters an error."""
    # Create a fresh instance of FaissVectorIOAdapter for testing
    config = MagicMock()
@ -143,7 +152,9 @@ async def test_health_failure():
    with patch("llama_stack.providers.inline.vector_io.faiss.faiss.faiss.IndexFlatL2") as mock_index_flat:
        mock_index_flat.side_effect = Exception("Test error")
        adapter = FaissVectorIOAdapter(config=config, inference_api=inference_api, files_api=files_api)
        adapter = FaissVectorIOAdapter(
            config=config, inference_api=inference_api, models_api=mock_models_api, files_api=files_api
        )

        # Calling the health method directly
        response = await adapter.health()


@ -6,16 +6,18 @@
import json
import time
from unittest.mock import AsyncMock, patch
from unittest.mock import AsyncMock, Mock, patch
import numpy as np
import pytest
from llama_stack.apis.common.errors import VectorStoreNotFoundError
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
Chunk,
OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
OpenAICreateVectorStoreRequestWithExtraBody,
QueryChunksResponse,
VectorStoreChunkingStrategyAuto,
VectorStoreFileObject,
@ -961,3 +963,93 @@ async def test_max_concurrent_files_per_batch(vector_io_adapter):
    assert batch.status == "in_progress"
    assert batch.file_counts.total == 8
    assert batch.file_counts.in_progress == 8
async def test_get_default_embedding_model_success(vector_io_adapter):
    """Test successful default embedding model detection."""
    # Mock models API with a default model
    mock_models_api = Mock()
    mock_models_api.list_models = AsyncMock(
        return_value=Mock(
            data=[
                Model(
                    identifier="nomic-embed-text-v1.5",
                    model_type=ModelType.embedding,
                    provider_id="test-provider",
                    metadata={
                        "embedding_dimension": 768,
                        "default_configured": True,
                    },
                )
            ]
        )
    )

    vector_io_adapter.models_api = mock_models_api

    result = await vector_io_adapter._get_default_embedding_model_and_dimension()
    assert result is not None
    model_id, dimension = result
    assert model_id == "nomic-embed-text-v1.5"
    assert dimension == 768


async def test_get_default_embedding_model_multiple_defaults_error(vector_io_adapter):
    """Test error when multiple models are marked as default."""
    mock_models_api = Mock()
    mock_models_api.list_models = AsyncMock(
        return_value=Mock(
            data=[
                Model(
                    identifier="model1",
                    model_type=ModelType.embedding,
                    provider_id="test-provider",
                    metadata={"embedding_dimension": 768, "default_configured": True},
                ),
                Model(
                    identifier="model2",
                    model_type=ModelType.embedding,
                    provider_id="test-provider",
                    metadata={"embedding_dimension": 512, "default_configured": True},
                ),
            ]
        )
    )

    vector_io_adapter.models_api = mock_models_api

    with pytest.raises(ValueError, match="Multiple embedding models marked as default_configured=True"):
        await vector_io_adapter._get_default_embedding_model_and_dimension()


async def test_openai_create_vector_store_uses_default_model(vector_io_adapter):
    """Test that vector store creation uses default embedding model when none specified."""
    # Mock models API and dependencies
    mock_models_api = Mock()
    mock_models_api.list_models = AsyncMock(
        return_value=Mock(
            data=[
                Model(
                    identifier="default-model",
                    model_type=ModelType.embedding,
                    provider_id="test-provider",
                    metadata={"embedding_dimension": 512, "default_configured": True},
                )
            ]
        )
    )

    vector_io_adapter.models_api = mock_models_api
    vector_io_adapter.register_vector_db = AsyncMock()
    vector_io_adapter.__provider_id__ = "test-provider"

    # Create vector store without specifying embedding model
    params = OpenAICreateVectorStoreRequestWithExtraBody(name="test-store")
    result = await vector_io_adapter.openai_create_vector_store(params)

    # Verify the vector store was created with default model
    assert result.name == "test-store"
    vector_io_adapter.register_vector_db.assert_called_once()
    call_args = vector_io_adapter.register_vector_db.call_args[0][0]
    assert call_args.embedding_model == "default-model"
    assert call_args.embedding_dimension == 512