Merge branch 'main' into vectordb_name

commit bd8c1cc071
Author: Francisco Arceo, 2025-07-14 23:42:45 -04:00 (committed by GitHub)
GPG key ID: B5690EEEBB952194 (no known key found for this signature in database)
52 changed files with 1363 additions and 921 deletions

.github/ISSUE_TEMPLATE/tech-debt.yml (new file, +30 lines)

@@ -0,0 +1,30 @@
name: 🔧 Tech Debt
description: Something that is functional but should be improved or optimized
labels: ["tech-debt"]
body:
- type: textarea
id: tech-debt-explanation
attributes:
label: 🤔 What is the technical debt you think should be addressed?
description: >
A clear and concise description of _what_ needs to be addressed - ensure that what you are describing
constitutes [technical debt](https://en.wikipedia.org/wiki/Technical_debt) and is not a bug
or feature request.
validations:
required: true
- type: textarea
id: tech-debt-motivation
attributes:
label: 💡 What is the benefit of addressing this technical debt?
description: >
A clear and concise description of _why_ this work is needed.
validations:
required: true
- type: textarea
id: other-thoughts
attributes:
label: Other thoughts
description: >
Any thoughts about how this may result in complexity in the codebase, or other trade-offs.


@@ -35,7 +35,7 @@ jobs:
       - name: Install minikube
         if: ${{ matrix.auth-provider == 'kubernetes' }}
-        uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # v0.0.19
+        uses: medyagh/setup-minikube@e3c7f79eb1e997eabccc536a6cf318a2b0fe19d9 # v0.0.20
       - name: Start minikube
         if: ${{ matrix.auth-provider == 'oauth2_token' }}


@@ -89,7 +89,7 @@ jobs:
           -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
           --text-model="ollama/llama3.2:3b-instruct-fp16" \
           --embedding-model=all-MiniLM-L6-v2 \
-          --safety-shield=ollama \
+          --safety-shield=$SAFETY_MODEL \
           --color=yes \
           --capture=tee-sys | tee pytest-${{ matrix.test-type }}.log


@@ -14795,7 +14795,8 @@
           "description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\""
         },
         "mode": {
-          "type": "string",
+          "$ref": "#/components/schemas/RAGSearchMode",
+          "default": "vector",
           "description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"."
         },
         "ranker": {
@@ -14830,6 +14831,16 @@
         }
       }
     },
+    "RAGSearchMode": {
+      "type": "string",
+      "enum": [
+        "vector",
+        "keyword",
+        "hybrid"
+      ],
+      "title": "RAGSearchMode",
+      "description": "Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search for semantic matching - KEYWORD: Uses keyword-based search for exact matching - HYBRID: Combines both vector and keyword search for better results"
+    },
     "RRFRanker": {
       "type": "object",
       "properties": {


@@ -10344,7 +10344,8 @@ components:
             content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
             {chunk.content}\nMetadata: {metadata}\n"
         mode:
-          type: string
+          $ref: '#/components/schemas/RAGSearchMode'
+          default: vector
           description: >-
             Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
             "vector".
@@ -10371,6 +10372,17 @@ components:
           mapping:
             default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
             llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
+    RAGSearchMode:
+      type: string
+      enum:
+        - vector
+        - keyword
+        - hybrid
+      title: RAGSearchMode
+      description: >-
+        Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search
+        for semantic matching - KEYWORD: Uses keyword-based search for exact matching
+        - HYBRID: Combines both vector and keyword search for better results
     RRFRanker:
       type: object
      properties:


@@ -145,6 +145,10 @@ $ llama stack build --template starter
 ...
 You can now edit ~/.llama/distributions/llamastack-starter/starter-run.yaml and run `llama stack run ~/.llama/distributions/llamastack-starter/starter-run.yaml`
 ```
+```{tip}
+The generated `run.yaml` file is a starting point for your configuration. For comprehensive guidance on customizing it for your specific needs, infrastructure, and deployment scenarios, see [Customizing Your run.yaml Configuration](customizing_run_yaml.md).
+```
 :::
 :::{tab-item} Building from Scratch


@@ -2,6 +2,10 @@
 The Llama Stack runtime configuration is specified as a YAML file. Here is a simplified version of an example configuration file for the Ollama distribution:
+```{note}
+The default `run.yaml` files generated by templates are starting points for your configuration. For guidance on customizing these files for your specific needs, see [Customizing Your run.yaml Configuration](customizing_run_yaml.md).
+```
 ```{dropdown} 👋 Click here for a Sample Configuration File
 ```yaml


@@ -0,0 +1,40 @@
# Customizing run.yaml Files
The `run.yaml` files generated by Llama Stack templates are **starting points** designed to be customized for your specific needs. They are not meant to be used as-is in production environments.
## Key Points
- **Templates are starting points**: Generated `run.yaml` files contain defaults for development/testing
- **Customization expected**: Update URLs, credentials, models, and settings for your environment
- **Version control separately**: Keep customized configs in your own repository
- **Environment-specific**: Create different configurations for dev, staging, production
## What You Can Customize
You can customize:
- **Provider endpoints**: Change `http://localhost:8000` to your actual servers
- **Swap providers**: Replace default providers (e.g., swap Tavily with Brave for search)
- **Storage paths**: Move from `/tmp/` to production directories
- **Authentication**: Add API keys, SSL, timeouts
- **Models**: Different model sizes for dev vs prod
- **Database settings**: Switch from SQLite to PostgreSQL
- **Tool configurations**: Add custom tools and integrations
## Best Practices
- Use environment variables for secrets and environment-specific values
- Create separate `run.yaml` files for different environments (dev, staging, prod)
- Document your changes with comments
- Test configurations before deployment
- Keep your customized configs in version control
Example structure:
```
your-project/
├── configs/
│ ├── dev-run.yaml
│ ├── prod-run.yaml
└── README.md
```
The goal is to take the generated template and adapt it to your specific infrastructure and operational needs.
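To make the environment-variable guidance concrete: the sample configurations elsewhere in this commit use `${env.VAR}` and `${env.VAR:=default}` placeholders. The sketch below shows how such placeholders could be expanded; the resolution rules are assumed from those samples, not taken from Llama Stack's actual loader.

```python
import os
import re

# Assumed placeholder grammar: ${env.NAME} or ${env.NAME:=default}
_PLACEHOLDER = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*)(?::=([^}]*))?\}")


def expand_env_placeholders(value: str) -> str:
    """Replace ${env.NAME:=default} placeholders with environment values."""

    def _sub(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        # Fall back to the inline default (or empty string) when unset
        return os.environ.get(name, default if default is not None else "")

    return _PLACEHOLDER.sub(_sub, value)


# Example: uses the default when SQLITE_STORE_DIR is not exported
print(expand_env_placeholders("${env.SQLITE_STORE_DIR:=~/.llama/dummy}/registry.db"))
```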


@@ -9,6 +9,7 @@ This section provides an overview of the distributions available in Llama Stack.
 importing_as_library
 configuration
+customizing_run_yaml
 list_of_distributions
 kubernetes_deployment
 building_distro


@@ -54,7 +54,7 @@ Llama Stack is a server that exposes multiple APIs, you connect with it using th
 You can use Python to build and run the Llama Stack server, which is useful for testing and development.
 Llama Stack uses a [YAML configuration file](../distributions/configuration.md) to specify the stack setup,
-which defines the providers and their settings.
+which defines the providers and their settings. The generated configuration serves as a starting point that you can [customize for your specific needs](../distributions/customizing_run_yaml.md).
 Now let's build and run the Llama Stack config for Ollama.
 We use `starter` as template. By default all providers are disabled, this requires enable ollama by passing environment variables.
@@ -77,7 +77,7 @@ ENABLE_OLLAMA=ollama INFERENCE_MODEL="llama3.2:3b" llama stack build --template
 You can use a container image to run the Llama Stack server. We provide several container images for the server
 component that works with different inference providers out of the box. For this guide, we will use
 `llamastack/distribution-starter` as the container image. If you'd like to build your own image or customize the
-configurations, please check out [this guide](../references/index.md).
+configurations, please check out [this guide](../distributions/building_distro.md).
 First lets setup some environment variables and create a local directory to mount into the containers file system.
 ```bash
 export INFERENCE_MODEL="llama3.2:3b"


@@ -114,7 +114,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
 | `uri` | `<class 'str'>` | No | PydanticUndefined | The URI of the Milvus server |
 | `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server |
 | `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |
-| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) |
+| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend |
 | `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. |
 > **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.
@@ -124,6 +124,9 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
 ```yaml
 uri: ${env.MILVUS_ENDPOINT}
 token: ${env.MILVUS_TOKEN}
+kvstore:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/milvus_remote_registry.db
 ```


@@ -40,6 +40,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
 | `db` | `str \| None` | No | postgres | |
 | `user` | `str \| None` | No | postgres | |
 | `password` | `str \| None` | No | mysecretpassword | |
+| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) |
 ## Sample Configuration
@@ -49,6 +50,9 @@ port: ${env.PGVECTOR_PORT:=5432}
 db: ${env.PGVECTOR_DB}
 user: ${env.PGVECTOR_USER}
 password: ${env.PGVECTOR_PASSWORD}
+kvstore:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/pgvector_registry.db
 ```


@@ -36,7 +36,9 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
 ## Sample Configuration
 ```yaml
-{}
+kvstore:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/weaviate_registry.db
 ```


@@ -87,6 +87,20 @@ class RAGQueryGenerator(Enum):
     custom = "custom"
+@json_schema_type
+class RAGSearchMode(Enum):
+    """
+    Search modes for RAG query retrieval:
+    - VECTOR: Uses vector similarity search for semantic matching
+    - KEYWORD: Uses keyword-based search for exact matching
+    - HYBRID: Combines both vector and keyword search for better results
+    """
+
+    VECTOR = "vector"
+    KEYWORD = "keyword"
+    HYBRID = "hybrid"
+
 @json_schema_type
 class DefaultRAGQueryGeneratorConfig(BaseModel):
     type: Literal["default"] = "default"
@@ -128,7 +142,7 @@ class RAGQueryConfig(BaseModel):
     max_tokens_in_context: int = 4096
     max_chunks: int = 5
     chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
-    mode: str | None = None
+    mode: RAGSearchMode | None = RAGSearchMode.VECTOR
     ranker: Ranker | None = Field(default=None)  # Only used for hybrid mode
     @field_validator("chunk_template")
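A quick sketch of what the typed field changes for callers; the import path is assumed from this repository's layout.

```python
from llama_stack.apis.tools.rag_tool import RAGQueryConfig, RAGSearchMode

# mode now defaults to RAGSearchMode.VECTOR instead of a bare string; passing
# the plain string "keyword" should still validate, since the enum's values
# are the same strings the field accepted before.
config = RAGQueryConfig(mode=RAGSearchMode.KEYWORD, max_chunks=5)
assert config.mode == RAGSearchMode.KEYWORD
```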


@@ -181,8 +181,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
             )
             self.cache[vector_db.identifier] = index
-        # Load existing OpenAI vector stores using the mixin method
-        self.openai_vector_stores = await self._load_openai_vector_stores()
+        # Load existing OpenAI vector stores into the in-memory cache
+        await self.initialize_openai_vector_stores()
     async def shutdown(self) -> None:
         # Cleanup if needed
@@ -261,42 +261,10 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
         return await index.query_chunks(query, params)
-    # OpenAI Vector Store Mixin abstract method implementations
-    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Save vector store metadata to kvstore."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
-        """Load all vector store metadata from kvstore."""
-        assert self.kvstore is not None
-        start_key = OPENAI_VECTOR_STORES_PREFIX
-        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
-        stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
-        stores = {}
-        for store_data in stored_openai_stores:
-            store_info = json.loads(store_data)
-            stores[store_info["id"]] = store_info
-        return stores
-    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Update vector store metadata in kvstore."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
-        """Delete vector store metadata from kvstore."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.delete(key)
     async def _save_openai_vector_store_file(
         self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
     ) -> None:
-        """Save vector store file metadata to kvstore."""
+        """Save vector store file data to kvstore."""
         assert self.kvstore is not None
         key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
         await self.kvstore.set(key=key, value=json.dumps(file_info))
@@ -324,7 +292,16 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
         await self.kvstore.set(key=key, value=json.dumps(file_info))
     async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
-        """Delete vector store file metadata from kvstore."""
+        """Delete vector store data from kvstore."""
         assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
-        await self.kvstore.delete(key)
+        keys_to_delete = [
+            f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}",
+            f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}",
+        ]
+        for key in keys_to_delete:
+            try:
+                await self.kvstore.delete(key)
+            except Exception as e:
+                logger.warning(f"Failed to delete key {key}: {e}")
+                continue


@@ -452,8 +452,8 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
             )
             self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
-        # load any existing OpenAI vector stores
-        self.openai_vector_stores = await self._load_openai_vector_stores()
+        # Load existing OpenAI vector stores into the in-memory cache
+        await self.initialize_openai_vector_stores()
     async def shutdown(self) -> None:
         # nothing to do since we don't maintain a persistent connection
@@ -501,41 +501,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
         await self.cache[vector_db_id].index.delete()
         del self.cache[vector_db_id]
-    # OpenAI Vector Store Mixin abstract method implementations
-    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Save vector store metadata to SQLite database."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-        self.openai_vector_stores[store_id] = store_info
-    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
-        """Load all vector store metadata from SQLite database."""
-        assert self.kvstore is not None
-        start_key = OPENAI_VECTOR_STORES_PREFIX
-        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
-        stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
-        stores = {}
-        for store_data in stored_openai_stores:
-            store_info = json.loads(store_data)
-            stores[store_info["id"]] = store_info
-        return stores
-    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Update vector store metadata in SQLite database."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-        self.openai_vector_stores[store_id] = store_info
-    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
-        """Delete vector store metadata from SQLite database."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.delete(key)
-        if store_id in self.openai_vector_stores:
-            del self.openai_vector_stores[store_id]
     async def _save_openai_vector_store_file(
         self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
     ) -> None:


@@ -12,6 +12,19 @@ from llama_stack.providers.utils.inference.model_registry import (
     build_model_entry,
 )
+SAFETY_MODELS_ENTRIES = [
+    # The Llama Guard models don't have their full fp16 versions
+    # so we are going to alias their default version to the canonical SKU
+    build_hf_repo_model_entry(
+        "llama-guard3:8b",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
+    build_hf_repo_model_entry(
+        "llama-guard3:1b",
+        CoreModelId.llama_guard_3_1b.value,
+    ),
+]
+
 MODEL_ENTRIES = [
     build_hf_repo_model_entry(
         "llama3.1:8b-instruct-fp16",
@@ -73,16 +86,6 @@ MODEL_ENTRIES = [
         "llama3.3:70b",
         CoreModelId.llama3_3_70b_instruct.value,
     ),
-    # The Llama Guard models don't have their full fp16 versions
-    # so we are going to alias their default version to the canonical SKU
-    build_hf_repo_model_entry(
-        "llama-guard3:8b",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama-guard3:1b",
-        CoreModelId.llama_guard_3_1b.value,
-    ),
     ProviderModelEntry(
         provider_model_id="all-minilm:l6-v2",
         aliases=["all-minilm"],
@@ -100,4 +103,4 @@ MODEL_ENTRIES = [
         "context_length": 8192,
     },
 ),
-]
+] + SAFETY_MODELS_ENTRIES


@@ -8,7 +8,7 @@ from typing import Any
 from pydantic import BaseModel, ConfigDict, Field
-from llama_stack.providers.utils.kvstore.config import KVStoreConfig
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.schema_utils import json_schema_type
@@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
     uri: str = Field(description="The URI of the Milvus server")
     token: str | None = Field(description="The token of the Milvus server")
     consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
-    kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
+    kvstore: KVStoreConfig = Field(description="Config for KV store backend")
     # This configuration allows additional fields to be passed through to the underlying Milvus client.
     # See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
@@ -25,4 +25,11 @@ class MilvusVectorIOConfig(BaseModel):
     @classmethod
     def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
-        return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}
+        return {
+            "uri": "${env.MILVUS_ENDPOINT}",
+            "token": "${env.MILVUS_TOKEN}",
+            "kvstore": SqliteKVStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="milvus_remote_registry.db",
+            ),
+        }


@@ -12,7 +12,7 @@ import re
 from typing import Any
 from numpy.typing import NDArray
-from pymilvus import DataType, MilvusClient
+from pymilvus import DataType, Function, FunctionType, MilvusClient
 from llama_stack.apis.files.files import Files
 from llama_stack.apis.inference import Inference, InterleavedContent
@@ -74,12 +74,66 @@ class MilvusIndex(EmbeddingIndex):
         assert len(chunks) == len(embeddings), (
             f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
         )
         if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
+            logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
+            # Create schema for vector search
+            schema = self.client.create_schema()
+            schema.add_field(
+                field_name="chunk_id",
+                datatype=DataType.VARCHAR,
+                is_primary=True,
+                max_length=100,
+            )
+            schema.add_field(
+                field_name="content",
+                datatype=DataType.VARCHAR,
+                max_length=65535,
+                enable_analyzer=True,  # Enable text analysis for BM25
+            )
+            schema.add_field(
+                field_name="vector",
+                datatype=DataType.FLOAT_VECTOR,
+                dim=len(embeddings[0]),
+            )
+            schema.add_field(
+                field_name="chunk_content",
+                datatype=DataType.JSON,
+            )
+            # Add sparse vector field for BM25 (required by the function)
+            schema.add_field(
+                field_name="sparse",
+                datatype=DataType.SPARSE_FLOAT_VECTOR,
+            )
+
+            # Create indexes
+            index_params = self.client.prepare_index_params()
+            index_params.add_index(
+                field_name="vector",
+                index_type="FLAT",
+                metric_type="COSINE",
+            )
+            # Add index for sparse field (required by BM25 function)
+            index_params.add_index(
+                field_name="sparse",
+                index_type="SPARSE_INVERTED_INDEX",
+                metric_type="BM25",
+            )
+
+            # Add BM25 function for full-text search
+            bm25_function = Function(
+                name="text_bm25_emb",
+                input_field_names=["content"],
+                output_field_names=["sparse"],
+                function_type=FunctionType.BM25,
+            )
+            schema.add_function(bm25_function)
+
             await asyncio.to_thread(
                 self.client.create_collection,
                 self.collection_name,
-                dimension=len(embeddings[0]),
-                auto_id=True,
+                schema=schema,
+                index_params=index_params,
                 consistency_level=self.consistency_level,
             )
@@ -88,8 +142,10 @@ class MilvusIndex(EmbeddingIndex):
             data.append(
                 {
                     "chunk_id": chunk.chunk_id,
+                    "content": chunk.content,
                     "vector": embedding,
                     "chunk_content": chunk.model_dump(),
+                    # sparse field will be handled by BM25 function automatically
                 }
             )
         try:
@@ -107,6 +163,7 @@ class MilvusIndex(EmbeddingIndex):
             self.client.search,
             collection_name=self.collection_name,
             data=[embedding],
+            anns_field="vector",
             limit=k,
             output_fields=["*"],
             search_params={"params": {"radius": score_threshold}},
@@ -121,7 +178,64 @@ class MilvusIndex(EmbeddingIndex):
         k: int,
         score_threshold: float,
     ) -> QueryChunksResponse:
-        raise NotImplementedError("Keyword search is not supported in Milvus")
+        """
+        Perform BM25-based keyword search using Milvus's built-in full-text search.
+        """
+        try:
+            # Use Milvus's built-in BM25 search
+            search_res = await asyncio.to_thread(
+                self.client.search,
+                collection_name=self.collection_name,
+                data=[query_string],  # Raw text query
+                anns_field="sparse",  # Use sparse field for BM25
+                output_fields=["chunk_content"],  # Output the chunk content
+                limit=k,
+                search_params={
+                    "params": {
+                        "drop_ratio_search": 0.2,  # Ignore low-importance terms
+                    }
+                },
+            )
+            chunks = []
+            scores = []
+            for res in search_res[0]:
+                chunk = Chunk(**res["entity"]["chunk_content"])
+                chunks.append(chunk)
+                scores.append(res["distance"])  # BM25 score from Milvus
+
+            # Filter by score threshold
+            filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
+            filtered_scores = [score for score in scores if score >= score_threshold]
+
+            return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
+        except Exception as e:
+            logger.error(f"Error performing BM25 search: {e}")
+            # Fallback to simple text search
+            return await self._fallback_keyword_search(query_string, k, score_threshold)
+
+    async def _fallback_keyword_search(
+        self,
+        query_string: str,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        """
+        Fallback to simple text search when BM25 search is not available.
+        """
+        # Simple text search using content field
+        search_res = await asyncio.to_thread(
+            self.client.query,
+            collection_name=self.collection_name,
+            filter='content like "%{content}%"',
+            filter_params={"content": query_string},
+            output_fields=["*"],
+            limit=k,
+        )
+        chunks = [Chunk(**res["chunk_content"]) for res in search_res]
+        scores = [1.0] * len(chunks)  # Simple binary score for text search
+        return QueryChunksResponse(chunks=chunks, scores=scores)
     async def query_hybrid(
         self,
@@ -179,7 +293,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         uri = os.path.expanduser(self.config.db_path)
         self.client = MilvusClient(uri=uri)
-        self.openai_vector_stores = await self._load_openai_vector_stores()
+        # Load existing OpenAI vector stores into the in-memory cache
+        await self.initialize_openai_vector_stores()
     async def shutdown(self) -> None:
         self.client.close()
@@ -246,38 +361,16 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
         if not index:
             raise ValueError(f"Vector DB {vector_db_id} not found")
+        if params and params.get("mode") == "keyword":
+            # Check if this is inline Milvus (Milvus-Lite)
+            if hasattr(self.config, "db_path"):
+                raise NotImplementedError(
+                    "Keyword search is not supported in Milvus-Lite. "
+                    "Please use a remote Milvus server for keyword search functionality."
+                )
+
         return await index.query_chunks(query, params)
-    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Save vector store metadata to persistent storage."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-        self.openai_vector_stores[store_id] = store_info
-    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Update vector store metadata in persistent storage."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-        self.openai_vector_stores[store_id] = store_info
-    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
-        """Delete vector store metadata from persistent storage."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.delete(key)
-        if store_id in self.openai_vector_stores:
-            del self.openai_vector_stores[store_id]
-    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
-        """Load all vector store metadata from persistent storage."""
-        assert self.kvstore is not None
-        start_key = OPENAI_VECTOR_STORES_PREFIX
-        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
-        stored = await self.kvstore.values_in_range(start_key, end_key)
-        return {json.loads(s)["id"]: json.loads(s) for s in stored}
     async def _save_openai_vector_store_file(
         self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
     ) -> None:
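For context, a caller exercises the new BM25 path through the ordinary chunk-query API. A hedged sketch follows; adapter construction is elided, the vector DB name is hypothetical, and only the `"mode"` key in `params` is confirmed by this diff:

```python
async def keyword_search(adapter, vector_db_id: str, text: str):
    # params={"mode": "keyword"} routes to the BM25 search added above;
    # Milvus-Lite (a config with db_path) raises NotImplementedError instead.
    response = await adapter.query_chunks(
        vector_db_id=vector_db_id,
        query=text,
        params={"mode": "keyword"},
    )
    for chunk, score in zip(response.chunks, response.scores, strict=False):
        print(f"{score:.3f} {chunk.chunk_id}")
    return response
```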


@@ -8,6 +8,10 @@ from typing import Any
 from pydantic import BaseModel, Field
+from llama_stack.providers.utils.kvstore.config import (
+    KVStoreConfig,
+    SqliteKVStoreConfig,
+)
 from llama_stack.schema_utils import json_schema_type
@@ -18,10 +22,12 @@ class PGVectorVectorIOConfig(BaseModel):
     db: str | None = Field(default="postgres")
     user: str | None = Field(default="postgres")
     password: str | None = Field(default="mysecretpassword")
+    kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
     @classmethod
     def sample_run_config(
         cls,
+        __distro_dir__: str,
         host: str = "${env.PGVECTOR_HOST:=localhost}",
         port: int = "${env.PGVECTOR_PORT:=5432}",
         db: str = "${env.PGVECTOR_DB}",
@@ -29,4 +35,14 @@ class PGVectorVectorIOConfig(BaseModel):
         password: str = "${env.PGVECTOR_PASSWORD}",
         **kwargs: Any,
     ) -> dict[str, Any]:
-        return {"host": host, "port": port, "db": db, "user": user, "password": password}
+        return {
+            "host": host,
+            "port": port,
+            "db": db,
+            "user": user,
+            "password": password,
+            "kvstore": SqliteKVStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="pgvector_registry.db",
+            ),
+        }


@@ -13,24 +13,18 @@ from psycopg2 import sql
 from psycopg2.extras import Json, execute_values
 from pydantic import BaseModel, TypeAdapter
+from llama_stack.apis.files.files import Files
 from llama_stack.apis.inference import InterleavedContent
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import (
     Chunk,
     QueryChunksResponse,
-    SearchRankingOptions,
     VectorIO,
-    VectorStoreChunkingStrategy,
-    VectorStoreDeleteResponse,
-    VectorStoreFileContentsResponse,
-    VectorStoreFileObject,
-    VectorStoreFileStatus,
-    VectorStoreListFilesResponse,
-    VectorStoreListResponse,
-    VectorStoreObject,
-    VectorStoreSearchResponsePage,
 )
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
@@ -40,6 +34,13 @@ from .config import PGVectorVectorIOConfig
 log = logging.getLogger(__name__)
+VERSION = "v3"
+VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
+VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::"
+OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:pgvector:{VERSION}::"
 def check_extension_version(cur):
     cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
@@ -69,7 +70,7 @@ def load_models(cur, cls):
 class PGVectorIndex(EmbeddingIndex):
-    def __init__(self, vector_db: VectorDB, dimension: int, conn):
+    def __init__(self, vector_db: VectorDB, dimension: int, conn, kvstore: KVStore | None = None):
         self.conn = conn
         with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
             # Sanitize the table name by replacing hyphens with underscores
@@ -77,6 +78,7 @@ class PGVectorIndex(EmbeddingIndex):
             # when created with patterns like "test-vector-db-{uuid4()}"
             sanitized_identifier = vector_db.identifier.replace("-", "_")
             self.table_name = f"vector_store_{sanitized_identifier}"
+            self.kvstore = kvstore
             cur.execute(
                 f"""
@@ -158,15 +160,28 @@ class PGVectorIndex(EmbeddingIndex):
         cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
-class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
-    def __init__(self, config: PGVectorVectorIOConfig, inference_api: Api.inference) -> None:
+class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
+    def __init__(
+        self,
+        config: PGVectorVectorIOConfig,
+        inference_api: Api.inference,
+        files_api: Files | None = None,
+    ) -> None:
         self.config = config
         self.inference_api = inference_api
         self.conn = None
         self.cache = {}
+        self.files_api = files_api
+        self.kvstore: KVStore | None = None
+        self.vector_db_store = None
+        self.openai_vector_store: dict[str, dict[str, Any]] = {}
+        self.metadatadata_collection_name = "openai_vector_stores_metadata"
     async def initialize(self) -> None:
         log.info(f"Initializing PGVector memory adapter with config: {self.config}")
+        self.kvstore = await kvstore_impl(self.config.kvstore)
+        await self.initialize_openai_vector_stores()
         try:
             self.conn = psycopg2.connect(
                 host=self.config.host,
@@ -201,14 +216,31 @@ class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         log.info("Connection to PGVector database server closed")
     async def register_vector_db(self, vector_db: VectorDB) -> None:
+        # Persist vector DB metadata in the KV store
+        assert self.kvstore is not None
+        key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
+        await self.kvstore.set(key=key, value=vector_db.model_dump_json())
+
+        # Upsert model metadata in Postgres
         upsert_models(self.conn, [(vector_db.identifier, vector_db)])
-        index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
-        self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
+
+        # Create and cache the PGVector index table for the vector DB
+        index = VectorDBWithIndex(
+            vector_db,
+            index=PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn, kvstore=self.kvstore),
+            inference_api=self.inference_api,
+        )
+        self.cache[vector_db.identifier] = index
     async def unregister_vector_db(self, vector_db_id: str) -> None:
-        await self.cache[vector_db_id].index.delete()
-        del self.cache[vector_db_id]
+        # Remove provider index and cache
+        if vector_db_id in self.cache:
+            await self.cache[vector_db_id].index.delete()
+            del self.cache[vector_db_id]
+
+        # Delete vector DB metadata from KV store
+        assert self.kvstore is not None
+        await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")
     async def insert_chunks(
         self,
@@ -237,106 +269,20 @@ class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
         return self.cache[vector_db_id]
-    async def openai_create_vector_store(
-        self,
-        name: str,
-        file_ids: list[str] | None = None,
-        expires_after: dict[str, Any] | None = None,
-        chunking_strategy: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-        embedding_model: str | None = None,
-        embedding_dimension: int | None = 384,
-        provider_id: str | None = None,
-    ) -> VectorStoreObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-    async def openai_list_vector_stores(
-        self,
-        limit: int | None = 20,
-        order: str | None = "desc",
-        after: str | None = None,
-        before: str | None = None,
-    ) -> VectorStoreListResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-    async def openai_retrieve_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-    async def openai_update_vector_store(
-        self,
-        vector_store_id: str,
-        name: str | None = None,
-        expires_after: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-    ) -> VectorStoreObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-    async def openai_delete_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreDeleteResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-    async def openai_search_vector_store(
-        self,
-        vector_store_id: str,
-        query: str | list[str],
-        filters: dict[str, Any] | None = None,
-        max_num_results: int | None = 10,
-        ranking_options: SearchRankingOptions | None = None,
-        rewrite_query: bool | None = False,
-        search_mode: str | None = "vector",
-    ) -> VectorStoreSearchResponsePage:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-    async def openai_attach_file_to_vector_store(
-        self,
-        vector_store_id: str,
-        file_id: str,
-        attributes: dict[str, Any] | None = None,
-        chunking_strategy: VectorStoreChunkingStrategy | None = None,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-    async def openai_list_files_in_vector_store(
-        self,
-        vector_store_id: str,
-        limit: int | None = 20,
-        order: str | None = "desc",
-        after: str | None = None,
-        before: str | None = None,
-        filter: VectorStoreFileStatus | None = None,
-    ) -> VectorStoreListFilesResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-    async def openai_retrieve_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-    async def openai_retrieve_vector_store_file_contents(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileContentsResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-    async def openai_update_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-        attributes: dict[str, Any] | None = None,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-    async def openai_delete_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
+    # OpenAI Vector Stores File operations are not supported in PGVector
+    async def _save_openai_vector_store_file(
+        self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
+    ) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
+    async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
+    async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
+    async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
+    async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")


@@ -6,15 +6,26 @@
 from typing import Any
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from llama_stack.providers.utils.kvstore.config import (
+    KVStoreConfig,
+    SqliteKVStoreConfig,
+)
 class WeaviateRequestProviderData(BaseModel):
     weaviate_api_key: str
     weaviate_cluster_url: str
+    kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
 class WeaviateVectorIOConfig(BaseModel):
     @classmethod
-    def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
-        return {}
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
+        return {
+            "kvstore": SqliteKVStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="weaviate_registry.db",
+            ),
+        }


@@ -14,10 +14,13 @@ from weaviate.classes.init import Auth
 from weaviate.classes.query import Filter
 from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.files.files import Files
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
@@ -27,11 +30,19 @@ from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig
 log = logging.getLogger(__name__)
+VERSION = "v3"
+VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
+VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::"
+OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:weaviate:{VERSION}::"
 class WeaviateIndex(EmbeddingIndex):
-    def __init__(self, client: weaviate.Client, collection_name: str):
+    def __init__(self, client: weaviate.Client, collection_name: str, kvstore: KVStore | None = None):
         self.client = client
         self.collection_name = collection_name
+        self.kvstore = kvstore
     async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
         assert len(chunks) == len(embeddings), (
@@ -109,11 +120,21 @@ class WeaviateVectorIOAdapter(
     NeedsRequestProviderData,
     VectorDBsProtocolPrivate,
 ):
-    def __init__(self, config: WeaviateVectorIOConfig, inference_api: Api.inference) -> None:
+    def __init__(
+        self,
+        config: WeaviateVectorIOConfig,
+        inference_api: Api.inference,
+        files_api: Files | None,
+    ) -> None:
         self.config = config
         self.inference_api = inference_api
         self.client_cache = {}
         self.cache = {}
+        self.files_api = files_api
+        self.kvstore: KVStore | None = None
+        self.vector_db_store = None
+        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
+        self.metadata_collection_name = "openai_vector_stores_metadata"
     def _get_client(self) -> weaviate.Client:
         provider_data = self.get_request_provider_data()
@@ -132,7 +153,26 @@ class WeaviateVectorIOAdapter(
         return client
     async def initialize(self) -> None:
-        pass
+        """Set up KV store and load existing vector DBs and OpenAI vector stores."""
+        # Initialize KV store for metadata
+        self.kvstore = await kvstore_impl(self.config.kvstore)
+
+        # Load existing vector DB definitions
+        start_key = VECTOR_DBS_PREFIX
+        end_key = f"{VECTOR_DBS_PREFIX}\xff"
+        stored = await self.kvstore.values_in_range(start_key, end_key)
+        for raw in stored:
+            vector_db = VectorDB.model_validate_json(raw)
+            client = self._get_client()
+            idx = WeaviateIndex(client=client, collection_name=vector_db.identifier, kvstore=self.kvstore)
+            self.cache[vector_db.identifier] = VectorDBWithIndex(
+                vector_db=vector_db,
+                index=idx,
+                inference_api=self.inference_api,
+            )
+
+        # Load OpenAI vector stores metadata into cache
+        await self.initialize_openai_vector_stores()
     async def shutdown(self) -> None:
         for client in self.client_cache.values():
@@ -206,3 +246,21 @@ class WeaviateVectorIOAdapter(
             raise ValueError(f"Vector DB {vector_db_id} not found")
         return await index.query_chunks(query, params)
+    # OpenAI Vector Stores File operations are not supported in Weaviate
+    async def _save_openai_vector_store_file(
+        self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
+    ) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+    async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+    async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+    async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+    async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")


@@ -83,9 +83,37 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
     def get_llama_model(self, provider_model_id: str) -> str | None:
         return self.provider_id_to_llama_model_map.get(provider_model_id, None)
+    async def check_model_availability(self, model: str) -> bool:
+        """
+        Check if a specific model is available from the provider (non-static check).
+
+        This is for subclassing purposes, so providers can check if a specific
+        model is currently available for use through dynamic means (e.g., API calls).
+
+        This method should NOT check statically configured model entries in
+        `self.alias_to_provider_id_map` - that is handled separately in register_model.
+
+        Default implementation returns False (no dynamic models available).
+
+        :param model: The model identifier to check.
+        :return: True if the model is available dynamically, False otherwise.
+        """
+        return False
+
     async def register_model(self, model: Model) -> Model:
-        if not (supported_model_id := self.get_provider_model_id(model.provider_resource_id)):
-            raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys())
+        # Check if model is supported in static configuration
+        supported_model_id = self.get_provider_model_id(model.provider_resource_id)
+
+        # If not found in static config, check if it's available dynamically from provider
+        if not supported_model_id:
+            if await self.check_model_availability(model.provider_resource_id):
+                supported_model_id = model.provider_resource_id
+            else:
+                # note: we cannot provide a complete list of supported models without
+                # getting a complete list from the provider, so we return "..."
+                all_supported_models = [*self.alias_to_provider_id_map.keys(), "..."]
+                raise UnsupportedModelError(model.provider_resource_id, all_supported_models)
+
         provider_resource_id = self.get_provider_model_id(model.model_id)
         if model.model_type == ModelType.embedding:
             # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
@@ -114,6 +142,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
                 ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
             )
+        # Register the model alias, ensuring it maps to the correct provider model id
         self.alias_to_provider_id_map[model.model_id] = supported_model_id
         return model
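To illustrate the intended subclassing, here is a minimal sketch of a provider that answers the dynamic check by querying its backing service. The `self._client.list_models()` helper is hypothetical and not part of this diff; only the hook itself is.

```python
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper


class MyRemoteAdapter(ModelRegistryHelper):
    async def check_model_availability(self, model: str) -> bool:
        # Hypothetical dynamic lookup: ask the remote service which models it
        # currently serves. register_model() falls back to this check when the
        # requested model is absent from the static alias map.
        served = await self._client.list_models()  # assumed client helper
        return model in served
```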


@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 import asyncio
+import json
 import logging
 import mimetypes
 import time
@ -35,6 +36,7 @@ from llama_stack.apis.vector_io import (
     VectorStoreSearchResponse,
     VectorStoreSearchResponsePage,
 )
+from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks
 
 logger = logging.getLogger(__name__)
@ -59,26 +61,45 @@ class OpenAIVectorStoreMixin(ABC):
     # These should be provided by the implementing class
     openai_vector_stores: dict[str, dict[str, Any]]
     files_api: Files | None
+    # KV store for persisting OpenAI vector store metadata
+    kvstore: KVStore | None
 
-    @abstractmethod
     async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
         """Save vector store metadata to persistent storage."""
-        pass
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.set(key=key, value=json.dumps(store_info))
+        # update in-memory cache
+        self.openai_vector_stores[store_id] = store_info
 
-    @abstractmethod
     async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
         """Load all vector store metadata from persistent storage."""
-        pass
+        assert self.kvstore is not None
+        start_key = OPENAI_VECTOR_STORES_PREFIX
+        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
+        stored_data = await self.kvstore.values_in_range(start_key, end_key)
+
+        stores: dict[str, dict[str, Any]] = {}
+        for item in stored_data:
+            info = json.loads(item)
+            stores[info["id"]] = info
+        return stores
 
-    @abstractmethod
     async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
         """Update vector store metadata in persistent storage."""
-        pass
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.set(key=key, value=json.dumps(store_info))
+        # update in-memory cache
+        self.openai_vector_stores[store_id] = store_info
 
-    @abstractmethod
     async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
         """Delete vector store metadata from persistent storage."""
-        pass
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.delete(key)
+        # remove from in-memory cache
+        self.openai_vector_stores.pop(store_id, None)
 
     @abstractmethod
     async def _save_openai_vector_store_file(
@ -117,6 +138,10 @@ class OpenAIVectorStoreMixin(ABC):
         """Unregister a vector database (provider-specific implementation)."""
         pass
 
+    async def initialize_openai_vector_stores(self) -> None:
+        """Load existing OpenAI vector stores into the in-memory cache."""
+        self.openai_vector_stores = await self._load_openai_vector_stores()
+
     @abstractmethod
     async def insert_chunks(
         self,

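The four persistence helpers all reduce to one key convention: each store lives at `OPENAI_VECTOR_STORES_PREFIX + store_id`, and a range scan from the prefix to `prefix + "\xff"` enumerates every store. A minimal self-contained sketch of that pattern, using a dict-backed stand-in for the KVStore protocol (both `FakeKVStore` and the prefix value are illustrative, not from the codebase):

import asyncio
import json

OPENAI_VECTOR_STORES_PREFIX = "openai_vector_stores:"  # assumed value, for illustration only


class FakeKVStore:
    """Dict-backed stand-in for the KVStore protocol used above (set / values_in_range / delete)."""

    def __init__(self):
        self._data: dict[str, str] = {}

    async def set(self, key: str, value: str) -> None:
        self._data[key] = value

    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
        return [v for k, v in sorted(self._data.items()) if start_key <= k <= end_key]

    async def delete(self, key: str) -> None:
        self._data.pop(key, None)


async def demo() -> None:
    kv = FakeKVStore()
    await kv.set(f"{OPENAI_VECTOR_STORES_PREFIX}vs_1", json.dumps({"id": "vs_1"}))
    await kv.set(f"{OPENAI_VECTOR_STORES_PREFIX}vs_2", json.dumps({"id": "vs_2"}))
    # "\xff" sorts above any printable byte, so the range covers every key under the prefix.
    raw = await kv.values_in_range(OPENAI_VECTOR_STORES_PREFIX, f"{OPENAI_VECTOR_STORES_PREFIX}\xff")
    stores = {json.loads(item)["id"]: json.loads(item) for item in raw}
    assert set(stores) == {"vs_1", "vs_2"}


asyncio.run(demo())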
View file

@ -68,7 +68,7 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="nvidia",
         distro_type="self_hosted",

View file

@ -128,6 +128,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_id="${env.ENABLE_PGVECTOR:+pgvector}",
             provider_type="remote::pgvector",
             config=PGVectorVectorIOConfig.sample_run_config(
+                f"~/.llama/distributions/{name}",
                 db="${env.PGVECTOR_DB:=}",
                 user="${env.PGVECTOR_USER:=}",
                 password="${env.PGVECTOR_PASSWORD:=}",
@ -146,7 +147,8 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    default_models = get_model_registry(available_models) + [
+    models, _ = get_model_registry(available_models)
+    default_models = models + [
         ModelInput(
             model_id="meta-llama/Llama-3.3-70B-Instruct",
             provider_id="groq",

View file

@ -54,6 +54,9 @@ providers:
         db: ${env.PGVECTOR_DB:=}
         user: ${env.PGVECTOR_USER:=}
         password: ${env.PGVECTOR_PASSWORD:=}
+        kvstore:
+          type: sqlite
+          db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/pgvector_registry.db
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard

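The `${env.VAR:=default}` placeholders in these run configs are resolved when the stack loads the config. A minimal sketch of that substitution rule (a simplification; the real resolver lives inside llama-stack and handles more placeholder forms than this):

import os
import re

_ENV_PLACEHOLDER = re.compile(r"\$\{env\.(?P<name>[A-Za-z0-9_]+):=(?P<default>[^}]*)\}")


def resolve_env_placeholders(value: str) -> str:
    """Replace each ${env.NAME:=default} with os.environ["NAME"] if set, else the default."""
    return _ENV_PLACEHOLDER.sub(
        lambda m: os.environ.get(m.group("name"), m.group("default")), value
    )


# With SQLITE_STORE_DIR unset, the kvstore path falls back to the bundled default
# (note: "~" expansion is left to the consumer in this sketch):
print(resolve_env_placeholders(
    "${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/pgvector_registry.db"
))
# -> ~/.llama/distributions/open-benchmark/pgvector_registry.db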
View file

@ -166,6 +166,9 @@ providers:
         db: ${env.PGVECTOR_DB:=}
         user: ${env.PGVECTOR_USER:=}
         password: ${env.PGVECTOR_PASSWORD:=}
+        kvstore:
+          type: sqlite
+          db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
   files:
   - provider_id: meta-reference-files
     provider_type: inline::localfs
@ -1171,24 +1174,8 @@ models:
   provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}
   model_type: embedding
 shields:
-- shield_id: ${env.ENABLE_OLLAMA:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=llama-guard3:1b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-8b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-11b-vision}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-8B}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-11B-Vision-Turbo}
-- shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/${env.SAFETY_MODEL:=sambanova/Meta-Llama-Guard-3-8B}
+- shield_id: ${env.SAFETY_MODEL:=__disabled__}
+  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}
 vector_dbs: []
 datasets: []
 scoring_fns: []

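Concretely, the starter distribution now ships one env-gated ollama shield instead of six provider-specific ones, and both fields stay `__disabled__` until the user exports the variables. Reusing the `resolve_env_placeholders` sketch from above (again an assumption about the real resolution code):

import os

os.environ["ENABLE_OLLAMA"] = "ollama"
os.environ["SAFETY_MODEL"] = "llama-guard3:1b"

shield_id = resolve_env_placeholders("${env.SAFETY_MODEL:=__disabled__}")
provider_shield_id = resolve_env_placeholders(
    "${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}"
)
assert shield_id == "llama-guard3:1b"
assert provider_shield_id == "ollama/llama-guard3:1b"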
View file

@ -12,7 +12,6 @@ from llama_stack.distribution.datatypes import (
     ModelInput,
     Provider,
     ProviderSpec,
-    ShieldInput,
     ToolGroupInput,
 )
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
@ -32,75 +31,39 @@ from llama_stack.providers.registry.inference import available_providers
 from llama_stack.providers.remote.inference.anthropic.models import (
     MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.anthropic.models import (
-    SAFETY_MODELS_ENTRIES as ANTHROPIC_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.bedrock.models import (
     MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.bedrock.models import (
-    SAFETY_MODELS_ENTRIES as BEDROCK_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.cerebras.models import (
     MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.cerebras.models import (
-    SAFETY_MODELS_ENTRIES as CEREBRAS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.databricks.databricks import (
     MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.databricks.databricks import (
-    SAFETY_MODELS_ENTRIES as DATABRICKS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.fireworks.models import (
     MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.fireworks.models import (
-    SAFETY_MODELS_ENTRIES as FIREWORKS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.gemini.models import (
     MODEL_ENTRIES as GEMINI_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.gemini.models import (
-    SAFETY_MODELS_ENTRIES as GEMINI_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.groq.models import (
     MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.groq.models import (
-    SAFETY_MODELS_ENTRIES as GROQ_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.nvidia.models import (
     MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.nvidia.models import (
-    SAFETY_MODELS_ENTRIES as NVIDIA_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.openai.models import (
     MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.openai.models import (
-    SAFETY_MODELS_ENTRIES as OPENAI_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.runpod.runpod import (
     MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.runpod.runpod import (
-    SAFETY_MODELS_ENTRIES as RUNPOD_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.sambanova.models import (
     MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.sambanova.models import (
-    SAFETY_MODELS_ENTRIES as SAMBANOVA_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.together.models import (
     MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.together.models import (
-    SAFETY_MODELS_ENTRIES as TOGETHER_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
 from llama_stack.providers.remote.vector_io.pgvector.config import (
     PGVectorVectorIOConfig,
@ -111,6 +74,7 @@ from llama_stack.templates.template import (
     DistributionTemplate,
     RunConfigSettings,
     get_model_registry,
+    get_shield_registry,
 )
@ -164,28 +128,13 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
 def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
     """Get model entries for a specific provider type."""
     safety_model_entries_map = {
-        "openai": OPENAI_SAFETY_MODELS_ENTRIES,
-        "fireworks": FIREWORKS_SAFETY_MODELS_ENTRIES,
-        "together": TOGETHER_SAFETY_MODELS_ENTRIES,
-        "anthropic": ANTHROPIC_SAFETY_MODELS_ENTRIES,
-        "gemini": GEMINI_SAFETY_MODELS_ENTRIES,
-        "groq": GROQ_SAFETY_MODELS_ENTRIES,
-        "sambanova": SAMBANOVA_SAFETY_MODELS_ENTRIES,
-        "cerebras": CEREBRAS_SAFETY_MODELS_ENTRIES,
-        "bedrock": BEDROCK_SAFETY_MODELS_ENTRIES,
-        "databricks": DATABRICKS_SAFETY_MODELS_ENTRIES,
-        "nvidia": NVIDIA_SAFETY_MODELS_ENTRIES,
-        "runpod": RUNPOD_SAFETY_MODELS_ENTRIES,
-    }
-
-    # Special handling for providers with dynamic model entries
-    if provider_type == "ollama":
-        return [
+        "ollama": [
             ProviderModelEntry(
-                provider_model_id="llama-guard3:1b",
+                provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
                 model_type=ModelType.llm,
             ),
-        ]
+        ],
+    }
 
     return safety_model_entries_map.get(provider_type, [])
@ -246,28 +195,20 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:
 
 # build a list of shields for all possible providers
-def get_shields_for_providers(providers: list[Provider]) -> list[ShieldInput]:
-    shields = []
+def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]:
+    available_models = {}
     for provider in providers:
         provider_type = provider.provider_type.split("::")[1]
         safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
         if len(safety_model_entries) == 0:
             continue
-        if provider.provider_id:
-            shield_id = provider.provider_id
-        else:
-            raise ValueError(f"Provider {provider.provider_type} has no provider_id")
-        for safety_model_entry in safety_model_entries:
-            print(f"provider.provider_id: {provider.provider_id}")
-            print(f"safety_model_entry.provider_model_id: {safety_model_entry.provider_model_id}")
-            shields.append(
-                ShieldInput(
-                    provider_id="llama-guard",
-                    shield_id=shield_id,
-                    provider_shield_id=f"{provider.provider_id}/${{env.SAFETY_MODEL:={safety_model_entry.provider_model_id}}}",
-                )
-            )
-    return shields
+
+        env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
+        provider_id = f"${{env.{env_var}:=__disabled__}}"
+
+        available_models[provider_id] = safety_model_entries
+
+    return available_models
 
 
 def get_distribution_template() -> DistributionTemplate:
@ -300,6 +241,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_id="${env.ENABLE_PGVECTOR:=__disabled__}",
             provider_type="remote::pgvector",
             config=PGVectorVectorIOConfig.sample_run_config(
+                f"~/.llama/distributions/{name}",
                 db="${env.PGVECTOR_DB:=}",
                 user="${env.PGVECTOR_USER:=}",
                 password="${env.PGVECTOR_PASSWORD:=}",
@ -307,8 +249,6 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]
 
-    shields = get_shields_for_providers(remote_inference_providers)
-
     providers = {
         "inference": ([p.provider_type for p in remote_inference_providers] + ["inline::sentence-transformers"]),
         "vector_io": ([p.provider_type for p in vector_io_providers]),
@ -361,7 +301,10 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )
 
-    default_models = get_model_registry(available_models)
+    default_models, ids_conflict_in_models = get_model_registry(available_models)
+
+    available_safety_models = get_safety_models_for_providers(remote_inference_providers)
+    shields = get_shield_registry(available_safety_models, ids_conflict_in_models)
+
     return DistributionTemplate(
         name=name,

View file

@ -37,7 +37,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as ge
 
 def get_model_registry(
     available_models: dict[str, list[ProviderModelEntry]],
-) -> list[ModelInput]:
+) -> tuple[list[ModelInput], bool]:
     models = []
 
     # check for conflicts in model ids
@ -74,7 +74,50 @@ def get_model_registry(
                 metadata=entry.metadata,
             )
         )
-    return models
+    return models, ids_conflict
+
+
+def get_shield_registry(
+    available_safety_models: dict[str, list[ProviderModelEntry]],
+    ids_conflict_in_models: bool,
+) -> list[ShieldInput]:
+    shields = []
+
+    # check for conflicts in shield ids
+    all_ids = set()
+    ids_conflict = False
+
+    for _, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                if model_id in all_ids:
+                    ids_conflict = True
+                    rich.print(
+                        f"[yellow]Shield id {model_id} conflicts; all shield ids will be prefixed with provider id[/yellow]"
+                    )
+                    break
+            all_ids.update(ids)
+            if ids_conflict:
+                break
+        if ids_conflict:
+            break
+
+    for provider_id, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
+                shields.append(
+                    ShieldInput(
+                        shield_id=identifier,
+                        provider_shield_id=f"{provider_id}/{entry.provider_model_id}"
+                        if ids_conflict_in_models
+                        else entry.provider_model_id,
+                    )
+                )
+
+    return shields
 
 
 class DefaultModel(BaseModel):

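Call sites now unpack the tuple: templates that build no shields (nvidia, watsonx) discard the conflict flag, while the starter template threads it into `get_shield_registry` so shield ids get the same provider prefixing as model ids. A small wiring sketch under assumed imports (the entry values are made up for illustration):

from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry
from llama_stack.templates.template import get_model_registry, get_shield_registry

available_models = {
    "${env.ENABLE_OLLAMA:=__disabled__}": [
        ProviderModelEntry(provider_model_id="llama3.2:3b-instruct-fp16"),
    ],
}
available_safety_models = {
    "${env.ENABLE_OLLAMA:=__disabled__}": [
        ProviderModelEntry(provider_model_id="${env.SAFETY_MODEL:=__disabled__}"),
    ],
}

default_models, ids_conflict_in_models = get_model_registry(available_models)
shields = get_shield_registry(available_safety_models, ids_conflict_in_models)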
View file

@ -69,7 +69,7 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )
 
-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="watsonx",
         distro_type="remote_hosted",

View file

@ -0,0 +1,82 @@
"use client";
import { useEffect, useState } from "react";
import { useParams, useRouter } from "next/navigation";
import { useAuthClient } from "@/hooks/use-auth-client";
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
import { VectorStoreDetailView } from "@/components/vector-stores/vector-store-detail";
export default function VectorStoreDetailPage() {
const params = useParams();
const id = params.id as string;
const client = useAuthClient();
const router = useRouter();
const [store, setStore] = useState<VectorStore | null>(null);
const [files, setFiles] = useState<VectorStoreFile[]>([]);
const [isLoadingStore, setIsLoadingStore] = useState(true);
const [isLoadingFiles, setIsLoadingFiles] = useState(true);
const [errorStore, setErrorStore] = useState<Error | null>(null);
const [errorFiles, setErrorFiles] = useState<Error | null>(null);
useEffect(() => {
if (!id) {
setErrorStore(new Error("Vector Store ID is missing."));
setIsLoadingStore(false);
return;
}
const fetchStore = async () => {
setIsLoadingStore(true);
setErrorStore(null);
try {
const response = await client.vectorStores.retrieve(id);
setStore(response as VectorStore);
} catch (err) {
setErrorStore(
err instanceof Error
? err
: new Error("Failed to load vector store."),
);
} finally {
setIsLoadingStore(false);
}
};
fetchStore();
}, [id, client]);
useEffect(() => {
if (!id) {
setErrorFiles(new Error("Vector Store ID is missing."));
setIsLoadingFiles(false);
return;
}
const fetchFiles = async () => {
setIsLoadingFiles(true);
setErrorFiles(null);
try {
const result = await client.vectorStores.files.list(id as any);
setFiles((result as any).data);
} catch (err) {
setErrorFiles(
err instanceof Error ? err : new Error("Failed to load files."),
);
} finally {
setIsLoadingFiles(false);
}
};
fetchFiles();
}, [id]);
return (
<VectorStoreDetailView
store={store}
files={files}
isLoadingStore={isLoadingStore}
isLoadingFiles={isLoadingFiles}
errorStore={errorStore}
errorFiles={errorFiles}
id={id}
/>
);
}

View file

@ -0,0 +1,16 @@
"use client";
import React from "react";
import LogsLayout from "@/components/layout/logs-layout";
export default function VectorStoresLayout({
children,
}: {
children: React.ReactNode;
}) {
return (
<LogsLayout sectionLabel="Vector Stores" basePath="/logs/vector-stores">
{children}
</LogsLayout>
);
}

View file

@ -0,0 +1,121 @@
"use client";
import React from "react";
import { useAuthClient } from "@/hooks/use-auth-client";
import type {
ListVectorStoresResponse,
VectorStore,
} from "llama-stack-client/resources/vector-stores/vector-stores";
import { useRouter } from "next/navigation";
import { usePagination } from "@/hooks/use-pagination";
import {
Table,
TableBody,
TableCaption,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import { Skeleton } from "@/components/ui/skeleton";
export default function VectorStoresPage() {
const client = useAuthClient();
const router = useRouter();
const {
data: stores,
status,
hasMore,
error,
loadMore,
} = usePagination<VectorStore>({
limit: 20,
order: "desc",
fetchFunction: async (client, params) => {
const response = await client.vectorStores.list({
after: params.after,
limit: params.limit,
order: params.order,
} as any);
return response as ListVectorStoresResponse;
},
errorMessagePrefix: "vector stores",
});
// Auto-load all pages for infinite scroll behavior (like Responses)
React.useEffect(() => {
if (status === "idle" && hasMore) {
loadMore();
}
}, [status, hasMore, loadMore]);
if (status === "loading") {
return (
<div className="space-y-2">
<Skeleton className="h-8 w-full" />
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-full" />
</div>
);
}
if (status === "error") {
return <div className="text-destructive">Error: {error?.message}</div>;
}
if (!stores || stores.length === 0) {
return <p>No vector stores found.</p>;
}
return (
<div className="overflow-auto flex-1 min-h-0">
<Table>
<TableHeader>
<TableRow>
<TableHead>ID</TableHead>
<TableHead>Name</TableHead>
<TableHead>Created</TableHead>
<TableHead>Completed</TableHead>
<TableHead>Cancelled</TableHead>
<TableHead>Failed</TableHead>
<TableHead>In Progress</TableHead>
<TableHead>Total</TableHead>
<TableHead>Usage Bytes</TableHead>
<TableHead>Provider ID</TableHead>
<TableHead>Provider Vector DB ID</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{stores.map((store) => {
const fileCounts = store.file_counts;
const metadata = store.metadata || {};
const providerId = metadata.provider_id ?? "";
const providerDbId = metadata.provider_vector_db_id ?? "";
return (
<TableRow
key={store.id}
onClick={() => router.push(`/logs/vector-stores/${store.id}`)}
className="cursor-pointer hover:bg-muted/50"
>
<TableCell>{store.id}</TableCell>
<TableCell>{store.name}</TableCell>
<TableCell>
{new Date(store.created_at * 1000).toLocaleString()}
</TableCell>
<TableCell>{fileCounts.completed}</TableCell>
<TableCell>{fileCounts.cancelled}</TableCell>
<TableCell>{fileCounts.failed}</TableCell>
<TableCell>{fileCounts.in_progress}</TableCell>
<TableCell>{fileCounts.total}</TableCell>
<TableCell>{store.usage_bytes}</TableCell>
<TableCell>{providerId}</TableCell>
<TableCell>{providerDbId}</TableCell>
</TableRow>
);
})}
</TableBody>
</Table>
</div>
);
}

View file

@ -1,6 +1,11 @@
 "use client";
 
-import { MessageSquareText, MessagesSquare, MoveUpRight } from "lucide-react";
+import {
+  MessageSquareText,
+  MessagesSquare,
+  MoveUpRight,
+  Database,
+} from "lucide-react";
 import Link from "next/link";
 import { usePathname } from "next/navigation";
 import { cn } from "@/lib/utils";
@ -28,6 +33,11 @@ const logItems = [
     url: "/logs/responses",
     icon: MessagesSquare,
   },
+  {
+    title: "Vector Stores",
+    url: "/logs/vector-stores",
+    icon: Database,
+  },
   {
     title: "Documentation",
     url: "https://llama-stack.readthedocs.io/en/latest/references/api_reference/index.html",
@ -57,13 +67,13 @@ export function AppSidebar() {
                   className={cn(
                     "justify-start",
                     isActive &&
-                      "bg-gray-200 hover:bg-gray-200 text-primary hover:text-primary",
+                      "bg-gray-200 dark:bg-gray-700 hover:bg-gray-200 dark:hover:bg-gray-700 text-gray-900 dark:text-gray-100",
                   )}
                 >
                   <Link href={item.url}>
                     <item.icon
                       className={cn(
-                        isActive && "text-primary",
+                        isActive && "text-gray-900 dark:text-gray-100",
                         "mr-2 h-4 w-4",
                       )}
                     />

View file

@ -93,7 +93,9 @@ export function PropertyItem({
     >
       <strong>{label}:</strong>{" "}
       {typeof value === "string" || typeof value === "number" ? (
-        <span className="text-gray-900 font-medium">{value}</span>
+        <span className="text-gray-900 dark:text-gray-100 font-medium">
+          {value}
+        </span>
       ) : (
         value
       )}
@ -112,7 +114,9 @@ export function PropertiesCard({ children }: PropertiesCardProps) {
         <CardTitle>Properties</CardTitle>
       </CardHeader>
       <CardContent>
-        <ul className="space-y-2 text-sm text-gray-600">{children}</ul>
+        <ul className="space-y-2 text-sm text-gray-600 dark:text-gray-400">
+          {children}
+        </ul>
       </CardContent>
     </Card>
   );

View file

@ -17,10 +17,10 @@ export const MessageBlock: React.FC<MessageBlockProps> = ({
 }) => {
   return (
     <div className={`mb-4 ${className}`}>
-      <p className="py-1 font-semibold text-gray-800 mb-1">
+      <p className="py-1 font-semibold text-muted-foreground mb-1">
         {label}
         {labelDetail && (
-          <span className="text-xs text-gray-500 font-normal ml-1">
+          <span className="text-xs text-muted-foreground font-normal ml-1">
             {labelDetail}
           </span>
         )}

View file

@ -0,0 +1,128 @@
"use client";
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Skeleton } from "@/components/ui/skeleton";
import {
DetailLoadingView,
DetailErrorView,
DetailNotFoundView,
DetailLayout,
PropertiesCard,
PropertyItem,
} from "@/components/layout/detail-layout";
import {
Table,
TableBody,
TableCaption,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
interface VectorStoreDetailViewProps {
store: VectorStore | null;
files: VectorStoreFile[];
isLoadingStore: boolean;
isLoadingFiles: boolean;
errorStore: Error | null;
errorFiles: Error | null;
id: string;
}
export function VectorStoreDetailView({
store,
files,
isLoadingStore,
isLoadingFiles,
errorStore,
errorFiles,
id,
}: VectorStoreDetailViewProps) {
const title = "Vector Store Details";
if (errorStore) {
return <DetailErrorView title={title} id={id} error={errorStore} />;
}
if (isLoadingStore) {
return <DetailLoadingView title={title} />;
}
if (!store) {
return <DetailNotFoundView title={title} id={id} />;
}
const mainContent = (
<>
<Card>
<CardHeader>
<CardTitle>Files</CardTitle>
</CardHeader>
<CardContent>
{isLoadingFiles ? (
<Skeleton className="h-4 w-full" />
) : errorFiles ? (
<div className="text-destructive text-sm">
Error loading files: {errorFiles.message}
</div>
) : files.length > 0 ? (
<Table>
<TableCaption>Files in this vector store</TableCaption>
<TableHeader>
<TableRow>
<TableHead>ID</TableHead>
<TableHead>Status</TableHead>
<TableHead>Created</TableHead>
<TableHead>Usage Bytes</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{files.map((file) => (
<TableRow key={file.id}>
<TableCell>{file.id}</TableCell>
<TableCell>{file.status}</TableCell>
<TableCell>
{new Date(file.created_at * 1000).toLocaleString()}
</TableCell>
<TableCell>{file.usage_bytes}</TableCell>
</TableRow>
))}
</TableBody>
</Table>
) : (
<p className="text-gray-500 italic text-sm">
No files in this vector store.
</p>
)}
</CardContent>
</Card>
</>
);
const sidebar = (
<PropertiesCard>
<PropertyItem label="ID" value={store.id} />
<PropertyItem label="Name" value={store.name || ""} />
<PropertyItem
label="Created"
value={new Date(store.created_at * 1000).toLocaleString()}
/>
<PropertyItem label="Status" value={store.status} />
<PropertyItem label="Total Files" value={store.file_counts.total} />
<PropertyItem label="Usage Bytes" value={store.usage_bytes} />
<PropertyItem
label="Provider ID"
value={(store.metadata.provider_id as string) || ""}
/>
<PropertyItem
label="Provider DB ID"
value={(store.metadata.provider_vector_db_id as string) || ""}
/>
</PropertiesCard>
);
return (
<DetailLayout title={title} mainContent={mainContent} sidebar={sidebar} />
);
}

View file

@ -15,7 +15,7 @@
         "@radix-ui/react-tooltip": "^1.2.6",
         "class-variance-authority": "^0.7.1",
         "clsx": "^2.1.1",
-        "llama-stack-client": "0.2.13",
+        "llama-stack-client": "^0.2.14",
         "lucide-react": "^0.510.0",
         "next": "15.3.3",
         "next-auth": "^4.24.11",
@ -676,406 +676,6 @@
       "tslib": "^2.4.0"
     }
   },
"node_modules/@esbuild/aix-ppc64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.25.5.tgz",
"integrity": "sha512-9o3TMmpmftaCMepOdA5k/yDw8SfInyzWWTjYTFCX3kPSDJMROQTb8jg+h9Cnwnmm1vOzvxN7gIfB5V2ewpjtGA==",
"cpu": [
"ppc64"
],
"license": "MIT",
"optional": true,
"os": [
"aix"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/android-arm": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/android-arm/-/android-arm-0.25.5.tgz",
"integrity": "sha512-AdJKSPeEHgi7/ZhuIPtcQKr5RQdo6OO2IL87JkianiMYMPbCtot9fxPbrMiBADOWWm3T2si9stAiVsGbTQFkbA==",
"cpu": [
"arm"
],
"license": "MIT",
"optional": true,
"os": [
"android"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/android-arm64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/android-arm64/-/android-arm64-0.25.5.tgz",
"integrity": "sha512-VGzGhj4lJO+TVGV1v8ntCZWJktV7SGCs3Pn1GRWI1SBFtRALoomm8k5E9Pmwg3HOAal2VDc2F9+PM/rEY6oIDg==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"android"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/android-x64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/android-x64/-/android-x64-0.25.5.tgz",
"integrity": "sha512-D2GyJT1kjvO//drbRT3Hib9XPwQeWd9vZoBJn+bu/lVsOZ13cqNdDeqIF/xQ5/VmWvMduP6AmXvylO/PIc2isw==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"android"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/darwin-arm64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.25.5.tgz",
"integrity": "sha512-GtaBgammVvdF7aPIgH2jxMDdivezgFu6iKpmT+48+F8Hhg5J/sfnDieg0aeG/jfSvkYQU2/pceFPDKlqZzwnfQ==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"darwin"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/darwin-x64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/darwin-x64/-/darwin-x64-0.25.5.tgz",
"integrity": "sha512-1iT4FVL0dJ76/q1wd7XDsXrSW+oLoquptvh4CLR4kITDtqi2e/xwXwdCVH8hVHU43wgJdsq7Gxuzcs6Iq/7bxQ==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"darwin"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/freebsd-arm64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/freebsd-arm64/-/freebsd-arm64-0.25.5.tgz",
"integrity": "sha512-nk4tGP3JThz4La38Uy/gzyXtpkPW8zSAmoUhK9xKKXdBCzKODMc2adkB2+8om9BDYugz+uGV7sLmpTYzvmz6Sw==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"freebsd"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/freebsd-x64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/freebsd-x64/-/freebsd-x64-0.25.5.tgz",
"integrity": "sha512-PrikaNjiXdR2laW6OIjlbeuCPrPaAl0IwPIaRv+SMV8CiM8i2LqVUHFC1+8eORgWyY7yhQY+2U2fA55mBzReaw==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"freebsd"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-arm": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/linux-arm/-/linux-arm-0.25.5.tgz",
"integrity": "sha512-cPzojwW2okgh7ZlRpcBEtsX7WBuqbLrNXqLU89GxWbNt6uIg78ET82qifUy3W6OVww6ZWobWub5oqZOVtwolfw==",
"cpu": [
"arm"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-arm64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/linux-arm64/-/linux-arm64-0.25.5.tgz",
"integrity": "sha512-Z9kfb1v6ZlGbWj8EJk9T6czVEjjq2ntSYLY2cw6pAZl4oKtfgQuS4HOq41M/BcoLPzrUbNd+R4BXFyH//nHxVg==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-ia32": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/linux-ia32/-/linux-ia32-0.25.5.tgz",
"integrity": "sha512-sQ7l00M8bSv36GLV95BVAdhJ2QsIbCuCjh/uYrWiMQSUuV+LpXwIqhgJDcvMTj+VsQmqAHL2yYaasENvJ7CDKA==",
"cpu": [
"ia32"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-loong64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/linux-loong64/-/linux-loong64-0.25.5.tgz",
"integrity": "sha512-0ur7ae16hDUC4OL5iEnDb0tZHDxYmuQyhKhsPBV8f99f6Z9KQM02g33f93rNH5A30agMS46u2HP6qTdEt6Q1kg==",
"cpu": [
"loong64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-mips64el": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/linux-mips64el/-/linux-mips64el-0.25.5.tgz",
"integrity": "sha512-kB/66P1OsHO5zLz0i6X0RxlQ+3cu0mkxS3TKFvkb5lin6uwZ/ttOkP3Z8lfR9mJOBk14ZwZ9182SIIWFGNmqmg==",
"cpu": [
"mips64el"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-ppc64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/linux-ppc64/-/linux-ppc64-0.25.5.tgz",
"integrity": "sha512-UZCmJ7r9X2fe2D6jBmkLBMQetXPXIsZjQJCjgwpVDz+YMcS6oFR27alkgGv3Oqkv07bxdvw7fyB71/olceJhkQ==",
"cpu": [
"ppc64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-riscv64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.25.5.tgz",
"integrity": "sha512-kTxwu4mLyeOlsVIFPfQo+fQJAV9mh24xL+y+Bm6ej067sYANjyEw1dNHmvoqxJUCMnkBdKpvOn0Ahql6+4VyeA==",
"cpu": [
"riscv64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-s390x": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.25.5.tgz",
"integrity": "sha512-K2dSKTKfmdh78uJ3NcWFiqyRrimfdinS5ErLSn3vluHNeHVnBAFWC8a4X5N+7FgVE1EjXS1QDZbpqZBjfrqMTQ==",
"cpu": [
"s390x"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-x64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.25.5.tgz",
"integrity": "sha512-uhj8N2obKTE6pSZ+aMUbqq+1nXxNjZIIjCjGLfsWvVpy7gKCOL6rsY1MhRh9zLtUtAI7vpgLMK6DxjO8Qm9lJw==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/netbsd-arm64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/netbsd-arm64/-/netbsd-arm64-0.25.5.tgz",
"integrity": "sha512-pwHtMP9viAy1oHPvgxtOv+OkduK5ugofNTVDilIzBLpoWAM16r7b/mxBvfpuQDpRQFMfuVr5aLcn4yveGvBZvw==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"netbsd"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/netbsd-x64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.25.5.tgz",
"integrity": "sha512-WOb5fKrvVTRMfWFNCroYWWklbnXH0Q5rZppjq0vQIdlsQKuw6mdSihwSo4RV/YdQ5UCKKvBy7/0ZZYLBZKIbwQ==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"netbsd"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/openbsd-arm64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/openbsd-arm64/-/openbsd-arm64-0.25.5.tgz",
"integrity": "sha512-7A208+uQKgTxHd0G0uqZO8UjK2R0DDb4fDmERtARjSHWxqMTye4Erz4zZafx7Di9Cv+lNHYuncAkiGFySoD+Mw==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"openbsd"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/openbsd-x64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.25.5.tgz",
"integrity": "sha512-G4hE405ErTWraiZ8UiSoesH8DaCsMm0Cay4fsFWOOUcz8b8rC6uCvnagr+gnioEjWn0wC+o1/TAHt+It+MpIMg==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"openbsd"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/sunos-x64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.25.5.tgz",
"integrity": "sha512-l+azKShMy7FxzY0Rj4RCt5VD/q8mG/e+mDivgspo+yL8zW7qEwctQ6YqKX34DTEleFAvCIUviCFX1SDZRSyMQA==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"sunos"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/win32-arm64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.25.5.tgz",
"integrity": "sha512-O2S7SNZzdcFG7eFKgvwUEZ2VG9D/sn/eIiz8XRZ1Q/DO5a3s76Xv0mdBzVM5j5R639lXQmPmSo0iRpHqUUrsxw==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/win32-ia32": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.25.5.tgz",
"integrity": "sha512-onOJ02pqs9h1iMJ1PQphR+VZv8qBMQ77Klcsqv9CNW2w6yLqoURLcgERAIurY6QE63bbLuqgP9ATqajFLK5AMQ==",
"cpu": [
"ia32"
],
"license": "MIT",
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/win32-x64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.25.5.tgz",
"integrity": "sha512-TXv6YnJ8ZMVdX+SXWVBo/0p8LTcrUYngpWjvm91TMjjBQii7Oz11Lw5lbDV5Y0TzuhSJHwiH4hEtC1I42mMS0g==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@eslint-community/eslint-utils": { "node_modules/@eslint-community/eslint-utils": {
"version": "4.7.0", "version": "4.7.0",
"resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.7.0.tgz", "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.7.0.tgz",
@ -5999,46 +5599,6 @@
"url": "https://github.com/sponsors/ljharb" "url": "https://github.com/sponsors/ljharb"
} }
}, },
"node_modules/esbuild": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.25.5.tgz",
"integrity": "sha512-P8OtKZRv/5J5hhz0cUAdu/cLuPIKXpQl1R9pZtvmHWQvrAUVd0UNIPT4IB4W3rNOqVO0rlqHmCIbSwxh/c9yUQ==",
"hasInstallScript": true,
"license": "MIT",
"bin": {
"esbuild": "bin/esbuild"
},
"engines": {
"node": ">=18"
},
"optionalDependencies": {
"@esbuild/aix-ppc64": "0.25.5",
"@esbuild/android-arm": "0.25.5",
"@esbuild/android-arm64": "0.25.5",
"@esbuild/android-x64": "0.25.5",
"@esbuild/darwin-arm64": "0.25.5",
"@esbuild/darwin-x64": "0.25.5",
"@esbuild/freebsd-arm64": "0.25.5",
"@esbuild/freebsd-x64": "0.25.5",
"@esbuild/linux-arm": "0.25.5",
"@esbuild/linux-arm64": "0.25.5",
"@esbuild/linux-ia32": "0.25.5",
"@esbuild/linux-loong64": "0.25.5",
"@esbuild/linux-mips64el": "0.25.5",
"@esbuild/linux-ppc64": "0.25.5",
"@esbuild/linux-riscv64": "0.25.5",
"@esbuild/linux-s390x": "0.25.5",
"@esbuild/linux-x64": "0.25.5",
"@esbuild/netbsd-arm64": "0.25.5",
"@esbuild/netbsd-x64": "0.25.5",
"@esbuild/openbsd-arm64": "0.25.5",
"@esbuild/openbsd-x64": "0.25.5",
"@esbuild/sunos-x64": "0.25.5",
"@esbuild/win32-arm64": "0.25.5",
"@esbuild/win32-ia32": "0.25.5",
"@esbuild/win32-x64": "0.25.5"
}
},
"node_modules/escalade": { "node_modules/escalade": {
"version": "3.2.0", "version": "3.2.0",
"resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz",
@ -6993,6 +6553,7 @@
"version": "2.3.3", "version": "2.3.3",
"resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz",
"integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==",
"dev": true,
"hasInstallScript": true, "hasInstallScript": true,
"license": "MIT", "license": "MIT",
"optional": true, "optional": true,
@ -7154,6 +6715,7 @@
"version": "4.10.0", "version": "4.10.0",
"resolved": "https://registry.npmjs.org/get-tsconfig/-/get-tsconfig-4.10.0.tgz", "resolved": "https://registry.npmjs.org/get-tsconfig/-/get-tsconfig-4.10.0.tgz",
"integrity": "sha512-kGzZ3LWWQcGIAmg6iWvXn0ei6WDtV26wzHRMwDSzmAbcXrTEXxHy6IehI6/4eT6VRKyMP1eF1VqwrVUmE/LR7A==", "integrity": "sha512-kGzZ3LWWQcGIAmg6iWvXn0ei6WDtV26wzHRMwDSzmAbcXrTEXxHy6IehI6/4eT6VRKyMP1eF1VqwrVUmE/LR7A==",
"dev": true,
"license": "MIT", "license": "MIT",
"dependencies": { "dependencies": {
"resolve-pkg-maps": "^1.0.0" "resolve-pkg-maps": "^1.0.0"
@ -9537,9 +9099,10 @@
      "license": "MIT"
    },
    "node_modules/llama-stack-client": {
-     "version": "0.2.13",
-     "resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.13.tgz",
-     "integrity": "sha512-R1rTFLwgUimr+KjEUkzUvFL6vLASwS9qj3UDSVkJ5BmrKAs5GwVAMeL7yZaTBXGuPUVh124WSlC4d9H0FjWqLA==",
+     "version": "0.2.14",
+     "resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.14.tgz",
+     "integrity": "sha512-bVU3JHp+EPEKR0Vb9vcd9ZyQj/72jSDuptKLwOXET9WrkphIQ8xuW5ueecMTgq8UEls3lwB3HiZM2cDOR9eDsQ==",
+     "license": "Apache-2.0",
      "dependencies": {
        "@types/node": "^18.11.18",
        "@types/node-fetch": "^2.6.4",
@ -9547,8 +9110,7 @@
        "agentkeepalive": "^4.2.1",
        "form-data-encoder": "1.7.2",
        "formdata-node": "^4.3.2",
-       "node-fetch": "^2.6.7",
-       "tsx": "^4.19.2"
+       "node-fetch": "^2.6.7"
      }
    },
    "node_modules/llama-stack-client/node_modules/@types/node": {
@ -11148,6 +10710,7 @@
      "version": "1.0.0",
      "resolved": "https://registry.npmjs.org/resolve-pkg-maps/-/resolve-pkg-maps-1.0.0.tgz",
      "integrity": "sha512-seS2Tj26TBVOC2NIc2rOe2y2ZO7efxITtLZcGSOnHHNOQ7CkiUBfw0Iw2ck6xkIhPwLhKNLS8BO+hEpngQlqzw==",
+     "dev": true,
      "license": "MIT",
      "funding": {
        "url": "https://github.com/privatenumber/resolve-pkg-maps?sponsor=1"
@ -12198,25 +11761,6 @@
      "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==",
      "license": "0BSD"
    },
"node_modules/tsx": {
"version": "4.19.4",
"resolved": "https://registry.npmjs.org/tsx/-/tsx-4.19.4.tgz",
"integrity": "sha512-gK5GVzDkJK1SI1zwHf32Mqxf2tSJkNx+eYcNly5+nHvWqXUJYUkWBQtKauoESz3ymezAI++ZwT855x5p5eop+Q==",
"license": "MIT",
"dependencies": {
"esbuild": "~0.25.0",
"get-tsconfig": "^4.7.5"
},
"bin": {
"tsx": "dist/cli.mjs"
},
"engines": {
"node": ">=18.0.0"
},
"optionalDependencies": {
"fsevents": "~2.3.3"
}
},
"node_modules/tw-animate-css": { "node_modules/tw-animate-css": {
"version": "1.2.9", "version": "1.2.9",
"resolved": "https://registry.npmjs.org/tw-animate-css/-/tw-animate-css-1.2.9.tgz", "resolved": "https://registry.npmjs.org/tw-animate-css/-/tw-animate-css-1.2.9.tgz",

View file

@ -20,7 +20,7 @@
         "@radix-ui/react-tooltip": "^1.2.6",
         "class-variance-authority": "^0.7.1",
         "clsx": "^2.1.1",
-        "llama-stack-client": "0.2.13",
+        "llama-stack-client": "^0.2.14",
         "lucide-react": "^0.510.0",
         "next": "15.3.3",
         "next-auth": "^4.24.11",

View file

@ -64,6 +64,7 @@ dev = [
     "pytest-cov",
     "pytest-html",
     "pytest-json-report",
+    "pytest-socket",  # For blocking network access in unit tests
     "nbval",  # For notebook testing
     "black",
     "ruff",
@ -344,3 +345,6 @@ classmethod-decorators = ["classmethod", "pydantic.field_validator"]
 
 [tool.pytest.ini_options]
 asyncio_mode = "auto"
+markers = [
+    "allow_network: Allow network access for specific unit tests",
+]

View file

@ -77,6 +77,24 @@ def agent_config(llama_stack_client, text_model_id):
     return agent_config
 
 
+@pytest.fixture(scope="session")
+def agent_config_without_safety(text_model_id):
+    agent_config = dict(
+        model=text_model_id,
+        instructions="You are a helpful assistant",
+        sampling_params={
+            "strategy": {
+                "type": "top_p",
+                "temperature": 0.0001,
+                "top_p": 0.9,
+            },
+        },
+        tools=[],
+        enable_session_persistence=False,
+    )
+    return agent_config
+
+
 def test_agent_simple(llama_stack_client, agent_config):
     agent = Agent(llama_stack_client, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
@ -491,7 +509,7 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
             assert expected_kw in response.output_message.content.lower()
 
 
-def test_rag_agent_with_attachments(llama_stack_client, agent_config):
+def test_rag_agent_with_attachments(llama_stack_client, agent_config_without_safety):
     urls = ["llama3.rst", "lora_finetune.rst"]
     documents = [
         # passing as url
@ -514,14 +532,8 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
             metadata={},
         ),
     ]
-    rag_agent = Agent(llama_stack_client, **agent_config)
+    rag_agent = Agent(llama_stack_client, **agent_config_without_safety)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
-    user_prompts = [
-        (
-            "Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
-            "grouped",
-        ),
-    ]
     user_prompts = [
         (
             "I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
@ -549,82 +561,6 @@ def test_rag_agent_with_attachments(llama_stack_client, agent_config):
         assert "lora" in response.output_message.content.lower()
@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
def test_rag_and_code_agent(llama_stack_client, agent_config):
if "llama-4" in agent_config["model"].lower():
pytest.xfail("Not working for llama4")
documents = []
documents.append(
Document(
document_id="nba_wiki",
content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).",
metadata={},
)
)
documents.append(
Document(
document_id="perplexity_wiki",
content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning:
Srinivas, the CEO, worked at OpenAI as an AI researcher.
Konwinski was among the founding team at Databricks.
Yarats, the CTO, was an AI research scientist at Meta.
Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""",
metadata={},
)
)
vector_db_id = f"test-vector-db-{uuid4()}"
llama_stack_client.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
)
llama_stack_client.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=vector_db_id,
chunk_size_in_tokens=128,
)
agent_config = {
**agent_config,
"tools": [
dict(
name="builtin::rag/knowledge_search",
args={"vector_db_ids": [vector_db_id]},
),
"builtin::code_interpreter",
],
}
agent = Agent(llama_stack_client, **agent_config)
user_prompts = [
(
"when was Perplexity the company founded?",
[],
"knowledge_search",
"2022",
),
(
"when was the nba created?",
[],
"knowledge_search",
"1949",
),
]
for prompt, docs, tool_name, expected_kw in user_prompts:
session_id = agent.create_session(f"test-session-{uuid4()}")
response = agent.create_turn(
messages=[{"role": "user", "content": prompt}],
session_id=session_id,
documents=docs,
stream=False,
)
tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
assert tool_execution_step.tool_calls[0].tool_name == tool_name, f"Failed on {prompt}"
if expected_kw:
assert expected_kw in response.output_message.content.lower()
@pytest.mark.parametrize(
    "client_tools",
    [(get_boiling_point, False), (get_boiling_point_with_metadata, True)],

View file

@ -4,6 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import pytest_socket
+
 # We need to import the fixtures here so that pytest can find them
 # but ruff doesn't think they are used and removes the import. "noqa: F401" prevents them from being removed
 from .fixtures import cached_disk_dist_registry, disk_dist_registry, sqlite_kvstore  # noqa: F401
+
+
+def pytest_runtest_setup(item):
+    """Setup for each test - check if network access should be allowed."""
+    if "allow_network" in item.keywords:
+        pytest_socket.enable_socket()
+    else:
+        # Allowing Unix sockets is necessary for some tests that use local servers and mocks
+        pytest_socket.disable_socket(allow_unix_socket=True)

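With this hook, unit tests are network-isolated by default and opt back in through the `allow_network` marker registered in pyproject.toml, as the vLLM event-loop test below does. A minimal sketch of the opt-in (the test body is illustrative):

import pytest


@pytest.mark.allow_network  # matches the marker declared in pyproject.toml
def test_streams_from_local_server():
    # pytest_runtest_setup above calls pytest_socket.enable_socket() for this
    # test; unmarked tests run with disable_socket(allow_unix_socket=True).
    ...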
View file

@ -393,6 +393,7 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
     assert chunks[0].event.event_type.value == "start"
 
 
+@pytest.mark.allow_network
 def test_chat_completion_doesnt_block_event_loop(caplog):
     loop = asyncio.new_event_loop()
     loop.set_debug(True)

View file

@ -87,6 +87,37 @@ def helper(known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry) -> ModelRegistryHelper:
     return ModelRegistryHelper([known_provider_model, known_provider_model2])
 
 
+class MockModelRegistryHelperWithDynamicModels(ModelRegistryHelper):
+    """Test helper that simulates a provider with dynamically available models."""
+
+    def __init__(self, model_entries: list[ProviderModelEntry], available_models: list[str]):
+        super().__init__(model_entries)
+        self._available_models = available_models
+
+    async def check_model_availability(self, model: str) -> bool:
+        return model in self._available_models
+
+
+@pytest.fixture
+def dynamic_model() -> Model:
+    """A model that's not in static config but available dynamically."""
+    return Model(
+        provider_id="provider",
+        identifier="dynamic-model",
+        provider_resource_id="dynamic-provider-id",
+    )
+
+
+@pytest.fixture
+def helper_with_dynamic_models(
+    known_provider_model: ProviderModelEntry, known_provider_model2: ProviderModelEntry, dynamic_model: Model
+) -> MockModelRegistryHelperWithDynamicModels:
+    """Helper that includes dynamically available models."""
+    return MockModelRegistryHelperWithDynamicModels(
+        [known_provider_model, known_provider_model2], [dynamic_model.provider_resource_id]
+    )
+
+
 async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
     assert helper.get_provider_model_id(unknown_model.model_id) is None
@ -151,3 +182,63 @@ async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
     assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
     await helper.unregister_model(known_model.provider_resource_id)
     assert helper.get_provider_model_id(known_model.provider_resource_id) is None
+
+
+async def test_register_model_from_check_model_availability(
+    helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model
+) -> None:
+    """Test that models returned by check_model_availability can be registered."""
+    # Verify the model is not in static config
+    assert helper_with_dynamic_models.get_provider_model_id(dynamic_model.provider_resource_id) is None
+
+    # But it should be available via check_model_availability
+    is_available = await helper_with_dynamic_models.check_model_availability(dynamic_model.provider_resource_id)
+    assert is_available
+
+    # Registration should succeed
+    registered_model = await helper_with_dynamic_models.register_model(dynamic_model)
+    assert registered_model == dynamic_model
+
+    # Model should now be registered and accessible
+    assert (
+        helper_with_dynamic_models.get_provider_model_id(dynamic_model.model_id) == dynamic_model.provider_resource_id
+    )
+
+
+async def test_register_model_not_in_static_or_dynamic(
+    helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, unknown_model: Model
+) -> None:
+    """Test that models not in static config or dynamic models are rejected."""
+    # Verify the model is not in static config
+    assert helper_with_dynamic_models.get_provider_model_id(unknown_model.provider_resource_id) is None
+
+    # And not available via check_model_availability
+    is_available = await helper_with_dynamic_models.check_model_availability(unknown_model.provider_resource_id)
+    assert not is_available
+
+    # Registration should fail with a comprehensive error message
+    with pytest.raises(Exception) as exc_info:  # UnsupportedModelError
+        await helper_with_dynamic_models.register_model(unknown_model)
+
+    # Error should include the static models and "..." for dynamic models
+    error_str = str(exc_info.value)
+    assert "..." in error_str
+
+
+async def test_register_alias_for_dynamic_model(
+    helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model
+) -> None:
+    """Test that we can register an alias that maps to a dynamically available model."""
+    # Create a model with a different identifier but the same provider_resource_id
+    alias_model = Model(
+        provider_id=dynamic_model.provider_id,
+        identifier="dynamic-model-alias",
+        provider_resource_id=dynamic_model.provider_resource_id,
+    )
+
+    # Registration should succeed since the provider_resource_id is available dynamically
+    registered_model = await helper_with_dynamic_models.register_model(alias_model)
+    assert registered_model == alias_model
+
+    # Both the original provider_resource_id and the new alias should work
+    assert helper_with_dynamic_models.get_provider_model_id(alias_model.model_id) == dynamic_model.provider_resource_id


@ -12,6 +12,8 @@ from pymilvus import MilvusClient, connections
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
@ -90,7 +92,7 @@ def sample_embeddings_with_metadata(sample_chunks_with_metadata):
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks_with_metadata])
-@pytest.fixture(params=["milvus", "sqlite_vec"])
+@pytest.fixture(params=["milvus", "sqlite_vec", "faiss"])
def vector_provider(request):
return request.param
@ -116,7 +118,7 @@ async def unique_kvstore_config(tmp_path_factory):
@pytest.fixture(scope="session")
def sqlite_vec_db_path(tmp_path_factory):
-db_path = str(tmp_path_factory.getbasetemp() / "test.db")
+db_path = str(tmp_path_factory.getbasetemp() / "test_sqlite_vec.db")
return db_path
@ -198,11 +200,49 @@ async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
await adapter.shutdown()
@pytest.fixture
def faiss_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db")
return db_path
@pytest.fixture
async def faiss_vec_index(embedding_dimension):
index = FaissIndex(embedding_dimension)
yield index
await index.delete()
@pytest.fixture
async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
config = FaissVectorIOConfig(
kvstore=unique_kvstore_config,
)
adapter = FaissVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=f"faiss_test_collection_{np.random.randint(1e6)}",
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
yield adapter
await adapter.shutdown()
@pytest.fixture
def vector_io_adapter(vector_provider, request):
"""Returns the appropriate vector IO adapter based on the provider parameter."""
if vector_provider == "milvus":
return request.getfixturevalue("milvus_vec_adapter")
elif vector_provider == "faiss":
return request.getfixturevalue("faiss_vec_adapter")
else:
return request.getfixturevalue("sqlite_vec_adapter")
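With the param list extended above, any test that takes vector_io_adapter now runs three times, once per backend. An illustrative consumer (not part of the diff):

async def test_adapter_initializes(vector_io_adapter):
    # Parametrized run: once each for milvus, sqlite_vec, and faiss.
    assert vector_io_adapter is not None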


@ -0,0 +1,191 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
import pytest_asyncio
from llama_stack.apis.vector_io import QueryChunksResponse
# Mock the entire pymilvus module
pymilvus_mock = MagicMock()
pymilvus_mock.DataType = MagicMock()
pymilvus_mock.MilvusClient = MagicMock
# Apply the mock before importing MilvusIndex
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
# tests which are specific to this class. More general (API-level) tests should be placed in
# tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_milvus.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
MILVUS_PROVIDER = "milvus"
@pytest_asyncio.fixture
async def mock_milvus_client() -> MagicMock:
"""Create a mock Milvus client with common method behaviors."""
client = MagicMock()
# Mock collection operations
client.has_collection.return_value = False # Initially no collection
client.create_collection.return_value = None
client.drop_collection.return_value = None
# Mock insert operation
client.insert.return_value = {"insert_count": 10}
# Mock search operation - return mock results (data should be dict, not JSON string)
client.search.return_value = [
[
{
"id": 0,
"distance": 0.1,
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
},
{
"id": 1,
"distance": 0.2,
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
},
]
]
# Mock query operation for keyword search (data should be dict, not JSON string)
client.query.return_value = [
{
"chunk_id": "chunk1",
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
"score": 0.9,
},
{
"chunk_id": "chunk2",
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
"score": 0.8,
},
{
"chunk_id": "chunk3",
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
"score": 0.7,
},
]
return client
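A note on the mock mechanics these tests rely on: return_value fixes the reply for every call, while side_effect given a list yields one item per successive call (or raises, if given an exception). For example:

from unittest.mock import MagicMock

client = MagicMock()
client.has_collection.side_effect = [False, True]
client.has_collection("c")  # -> False (the index will create the collection)
client.has_collection("c")  # -> True (subsequent inserts skip creation)
client.search.side_effect = Exception("boom")  # every call now raises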
@pytest_asyncio.fixture
async def milvus_index(mock_milvus_client):
"""Create a MilvusIndex with mocked client."""
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
yield index
# No real cleanup needed since we're using mocks
@pytest.mark.asyncio
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
# Setup: collection doesn't exist initially, then exists after creation
mock_milvus_client.has_collection.side_effect = [False, True]
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Verify collection was created and data was inserted
mock_milvus_client.create_collection.assert_called_once()
mock_milvus_client.insert.assert_called_once()
# Verify the insert call had the right number of chunks
insert_call = mock_milvus_client.insert.call_args
assert len(insert_call[1]["data"]) == len(sample_chunks)
@pytest.mark.asyncio
async def test_query_chunks_vector(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
# Setup: Add chunks first
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test vector search
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
mock_milvus_client.search.assert_called_once()
@pytest.mark.asyncio
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Test keyword search
query_string = "Sentence 5"
response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
@pytest.mark.asyncio
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
"""Test that when BM25 search fails, the system falls back to simple text search."""
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
# Force BM25 search to fail
mock_milvus_client.search.side_effect = Exception("BM25 search not available")
# Mock simple text search results
mock_milvus_client.query.return_value = [
{
"chunk_id": "chunk1",
"chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
},
{
"chunk_id": "chunk2",
"chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
},
]
# Test keyword search that should fall back to simple text search
query_string = "Python"
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
# Verify response structure
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) > 0, "Fallback search should return results"
# Verify that simple text search was used (query method called instead of search)
mock_milvus_client.query.assert_called_once()
mock_milvus_client.search.assert_called_once() # Called once but failed
# Verify the query uses parameterized filter with filter_params
query_call_args = mock_milvus_client.query.call_args
assert "filter" in query_call_args[1], "Query should include filter for text search"
assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"
# Verify all returned chunks have score 1.0 (simple binary scoring)
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
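The fallback pattern being verified is roughly the following (a sketch under assumed names, not the provider's exact code): attempt the BM25 search, and on failure degrade to a parameterized substring query with binary scores.

async def query_keyword(self, query_string: str, k: int, score_threshold: float):
    try:
        rows = self.client.search(...)  # BM25 full-text search path
    except Exception:
        # Fallback: parameterized text filter (expression syntax is an assumption)
        rows = self.client.query(
            collection_name=self.collection_name,
            filter='chunk_content["content"] like "%{content}%"',
            filter_params={"content": query_string},
            limit=k,
        )
        scores = [1.0] * len(rows)  # simple binary scoring on the fallback path
    ...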
@pytest.mark.asyncio
async def test_delete_collection(milvus_index, mock_milvus_client):
# Test collection deletion
mock_milvus_client.has_collection.return_value = True
await milvus_index.delete()
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)


@ -94,7 +94,7 @@ async def test_query_unregistered_raises(vector_io_adapter):
async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
fake_index = AsyncMock()
-vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
+vector_io_adapter.cache["db1"] = fake_index
chunks = ["chunk1", "chunk2"]
await vector_io_adapter.insert_chunks("db1", chunks)
@ -112,7 +112,7 @@ async def test_insert_chunks_missing_db_raises(vector_io_adapter):
async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter):
expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
-vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=fake_index)
+vector_io_adapter.cache["db1"] = fake_index
response = await vector_io_adapter.query_chunks("db1", "my_query", {"param": 1})
@ -286,5 +286,7 @@ async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, t
await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
await vector_io_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id)
loaded_file_info = await vector_io_adapter._load_openai_vector_store_file(store_id, file_id)
assert loaded_file_info == {}
loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id)
assert loaded_contents == []


@ -8,6 +8,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from llama_stack.apis.tools.rag_tool import RAGQueryConfig
from llama_stack.apis.vector_io import (
Chunk,
ChunkMetadata,
@ -58,3 +59,14 @@ class TestRagQuery:
)
assert expected_metadata_string in result.content[1].text
assert result.content is not None
async def test_query_raises_incorrect_mode(self):
with pytest.raises(ValueError):
RAGQueryConfig(mode="invalid_mode")
@pytest.mark.asyncio
async def test_query_accepts_valid_modes(self):
RAGQueryConfig() # Test default (vector)
RAGQueryConfig(mode="vector") # Test vector
RAGQueryConfig(mode="keyword") # Test keyword
RAGQueryConfig(mode="hybrid") # Test hybrid


@ -123,6 +123,7 @@ class TestVectorStore:
content = await content_from_doc(doc)
assert content in DUMMY_PDF_TEXT_CHOICES
@pytest.mark.allow_network
async def test_downloads_pdf_and_returns_content(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"
@ -135,6 +136,7 @@ class TestVectorStore:
content = await content_from_doc(doc)
assert content in DUMMY_PDF_TEXT_CHOICES
@pytest.mark.allow_network
async def test_downloads_pdf_and_returns_content_with_url_object(self):
# Using GitHub to host the PDF file
url = "https://raw.githubusercontent.com/meta-llama/llama-stack/da035d69cfca915318eaf485770a467ca3c2a238/llama_stack/providers/tests/memory/fixtures/dummy.pdf"

uv.lock generated

@ -1324,6 +1324,7 @@ dev = [
{ name = "pytest-cov" }, { name = "pytest-cov" },
{ name = "pytest-html" }, { name = "pytest-html" },
{ name = "pytest-json-report" }, { name = "pytest-json-report" },
{ name = "pytest-socket" },
{ name = "pytest-timeout" }, { name = "pytest-timeout" },
{ name = "ruamel-yaml" }, { name = "ruamel-yaml" },
{ name = "ruff" }, { name = "ruff" },
@ -1432,6 +1433,7 @@ dev = [
{ name = "pytest-cov" }, { name = "pytest-cov" },
{ name = "pytest-html" }, { name = "pytest-html" },
{ name = "pytest-json-report" }, { name = "pytest-json-report" },
{ name = "pytest-socket" },
{ name = "pytest-timeout" }, { name = "pytest-timeout" },
{ name = "ruamel-yaml" }, { name = "ruamel-yaml" },
{ name = "ruff" }, { name = "ruff" },
@ -2545,6 +2547,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/3e/43/7e7b2ec865caa92f67b8f0e9231a798d102724ca4c0e1f414316be1c1ef2/pytest_metadata-3.1.1-py3-none-any.whl", hash = "sha256:c8e0844db684ee1c798cfa38908d20d67d0463ecb6137c72e91f418558dd5f4b", size = 11428, upload-time = "2024-02-12T19:38:42.531Z" }, { url = "https://files.pythonhosted.org/packages/3e/43/7e7b2ec865caa92f67b8f0e9231a798d102724ca4c0e1f414316be1c1ef2/pytest_metadata-3.1.1-py3-none-any.whl", hash = "sha256:c8e0844db684ee1c798cfa38908d20d67d0463ecb6137c72e91f418558dd5f4b", size = 11428, upload-time = "2024-02-12T19:38:42.531Z" },
] ]
[[package]]
name = "pytest-socket"
version = "0.7.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pytest" },
]
sdist = { url = "https://files.pythonhosted.org/packages/05/ff/90c7e1e746baf3d62ce864c479fd53410b534818b9437413903596f81580/pytest_socket-0.7.0.tar.gz", hash = "sha256:71ab048cbbcb085c15a4423b73b619a8b35d6a307f46f78ea46be51b1b7e11b3", size = 12389, upload-time = "2024-01-28T20:17:23.177Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/19/58/5d14cb5cb59409e491ebe816c47bf81423cd03098ea92281336320ae5681/pytest_socket-0.7.0-py3-none-any.whl", hash = "sha256:7e0f4642177d55d317bbd58fc68c6bd9048d6eadb2d46a89307fa9221336ce45", size = 6754, upload-time = "2024-01-28T20:17:22.105Z" },
]
[[package]]
name = "pytest-timeout"
version = "2.4.0"