Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-23 07:02:28 +00:00)
Merge branch 'main' into fix/issue-2584-llama4-tool-calling
Commit 5679d4dfd6

26 changed files with 669 additions and 507 deletions
.github/ISSUE_TEMPLATE/tech-debt.yml (new file, 30 lines, vendored)

@@ -0,0 +1,30 @@
+name: 🔧 Tech Debt
+description: Something that is functional but should be improved or optimized
+labels: ["tech-debt"]
+body:
+  - type: textarea
+    id: tech-debt-explanation
+    attributes:
+      label: 🤔 What is the technical debt you think should be addressed?
+      description: >
+        A clear and concise description of _what_ needs to be addressed - ensure you are describing what
+        constitutes [technical debt](https://en.wikipedia.org/wiki/Technical_debt) and is not a bug
+        or feature request.
+    validations:
+      required: true
+
+  - type: textarea
+    id: tech-debt-motivation
+    attributes:
+      label: 💡 What is the benefit of addressing this technical debt?
+      description: >
+        A clear and concise description of _why_ this work is needed.
+    validations:
+      required: true
+
+  - type: textarea
+    id: other-thoughts
+    attributes:
+      label: Other thoughts
+      description: >
+        Any thoughts about how this may result in complexity in the codebase, or other trade-offs.
.github/workflows/integration-tests.yml (2 changes, vendored)

@@ -89,7 +89,7 @@ jobs:
          -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
          --text-model="ollama/llama3.2:3b-instruct-fp16" \
          --embedding-model=all-MiniLM-L6-v2 \
-          --safety-shield=ollama \
+          --safety-shield=$SAFETY_MODEL \
          --color=yes \
          --capture=tee-sys | tee pytest-${{ matrix.test-type }}.log
.github/workflows/tests.yml (deleted, 69 lines, vendored)

@@ -1,69 +0,0 @@
-name: auto-tests
-
-on:
-  # pull_request:
-  workflow_dispatch:
-    inputs:
-      commit_sha:
-        description: 'Specific Commit SHA to trigger on'
-        required: false
-        default: $GITHUB_SHA  # default to the last commit of $GITHUB_REF branch
-
-jobs:
-  test-llama-stack-as-library:
-    runs-on: ubuntu-latest
-    env:
-      TOGETHER_API_KEY: ${{ secrets.TOGETHER_API_KEY }}
-      FIREWORKS_API_KEY: ${{ secrets.FIREWORKS_API_KEY }}
-      TAVILY_SEARCH_API_KEY: ${{ secrets.TAVILY_SEARCH_API_KEY }}
-    strategy:
-      matrix:
-        provider: [fireworks, together]
-    steps:
-      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
-        with:
-          ref: ${{ github.event.inputs.commit_sha }}
-
-      - name: Echo commit SHA
-        run: |
-          echo "Triggered on commit SHA: ${{ github.event.inputs.commit_sha }}"
-          git rev-parse HEAD
-
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install -r requirements.txt pytest
-          pip install -e .
-
-      - name: Build providers
-        run: |
-          llama stack build --template ${{ matrix.provider }} --image-type venv
-
-      - name: Install the latest llama-stack-client & llama-models packages
-        run: |
-          pip install -e git+https://github.com/meta-llama/llama-stack-client-python.git#egg=llama-stack-client
-          pip install -e git+https://github.com/meta-llama/llama-models.git#egg=llama-models
-
-      - name: Run client-sdk test
-        working-directory: "${{ github.workspace }}"
-        env:
-          REPORT_OUTPUT: md_report.md
-        shell: bash
-        run: |
-          pip install --upgrade pytest-md-report
-          echo "REPORT_FILE=${REPORT_OUTPUT}" >> "$GITHUB_ENV"
-
-          export INFERENCE_MODEL=meta-llama/Llama-3.1-8B-Instruct
-          LLAMA_STACK_CONFIG=./llama_stack/templates/${{ matrix.provider }}/run.yaml pytest --md-report --md-report-verbose=1 ./tests/client-sdk/inference/ --md-report-output "$REPORT_OUTPUT"
-
-      - name: Output reports to the job summary
-        if: always()
-        shell: bash
-        run: |
-          if [ -f "$REPORT_FILE" ]; then
-            echo "<details><summary> Test Report for ${{ matrix.provider }} </summary>" >> $GITHUB_STEP_SUMMARY
-            echo "" >> $GITHUB_STEP_SUMMARY
-            cat "$REPORT_FILE" >> $GITHUB_STEP_SUMMARY
-            echo "" >> $GITHUB_STEP_SUMMARY
-            echo "</details>" >> $GITHUB_STEP_SUMMARY
-          fi
@@ -77,7 +77,7 @@ ENABLE_OLLAMA=ollama INFERENCE_MODEL="llama3.2:3b" llama stack build --template
You can use a container image to run the Llama Stack server. We provide several container images for the server
component that works with different inference providers out of the box. For this guide, we will use
`llamastack/distribution-starter` as the container image. If you'd like to build your own image or customize the
-configurations, please check out [this guide](../references/index.md).
+configurations, please check out [this guide](../distributions/building_distro.md).
First lets setup some environment variables and create a local directory to mount into the container’s file system.
```bash
export INFERENCE_MODEL="llama3.2:3b"
@@ -114,7 +114,7 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
| `uri` | `<class 'str'>` | No | PydanticUndefined | The URI of the Milvus server |
| `token` | `str \| None` | No | PydanticUndefined | The token of the Milvus server |
| `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |
-| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) |
+| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend |
| `config` | `dict` | No | {} | This configuration allows additional fields to be passed through to the underlying Milvus client. See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general. |

> **Note**: This configuration class accepts additional fields beyond those listed above. You can pass any additional configuration options that will be forwarded to the underlying provider.
@@ -124,6 +124,9 @@ For more details on TLS configuration, refer to the [TLS setup guide](https://mi
```yaml
uri: ${env.MILVUS_ENDPOINT}
token: ${env.MILVUS_TOKEN}
+kvstore:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/milvus_remote_registry.db

```
@@ -40,6 +40,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
| `db` | `str \| None` | No | postgres | |
| `user` | `str \| None` | No | postgres | |
| `password` | `str \| None` | No | mysecretpassword | |
+| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) |

## Sample Configuration
@@ -49,6 +50,9 @@ port: ${env.PGVECTOR_PORT:=5432}
db: ${env.PGVECTOR_DB}
user: ${env.PGVECTOR_USER}
password: ${env.PGVECTOR_PASSWORD}
+kvstore:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/pgvector_registry.db

```
@@ -36,7 +36,9 @@ See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more
## Sample Configuration

```yaml
-{}
+kvstore:
+  type: sqlite
+  db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/weaviate_registry.db

```
@@ -181,8 +181,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
            )
            self.cache[vector_db.identifier] = index

-        # Load existing OpenAI vector stores using the mixin method
-        self.openai_vector_stores = await self._load_openai_vector_stores()
+        # Load existing OpenAI vector stores into the in-memory cache
+        await self.initialize_openai_vector_stores()

    async def shutdown(self) -> None:
        # Cleanup if needed
@@ -261,42 +261,6 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr

        return await index.query_chunks(query, params)

-    # OpenAI Vector Store Mixin abstract method implementations
-    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Save vector store metadata to kvstore."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-        self.openai_vector_stores[store_id] = store_info
-
-    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
-        """Load all vector store metadata from kvstore."""
-        assert self.kvstore is not None
-        start_key = OPENAI_VECTOR_STORES_PREFIX
-        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
-        stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
-
-        stores = {}
-        for store_data in stored_openai_stores:
-            store_info = json.loads(store_data)
-            stores[store_info["id"]] = store_info
-        return stores
-
-    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Update vector store metadata in kvstore."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-        self.openai_vector_stores[store_id] = store_info
-
-    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
-        """Delete vector store metadata from kvstore."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.delete(key)
-        if store_id in self.openai_vector_stores:
-            del self.openai_vector_stores[store_id]
-
    async def _save_openai_vector_store_file(
        self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
    ) -> None:
@@ -452,8 +452,8 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
            )
            self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)

-        # load any existing OpenAI vector stores
-        self.openai_vector_stores = await self._load_openai_vector_stores()
+        # Load existing OpenAI vector stores into the in-memory cache
+        await self.initialize_openai_vector_stores()

    async def shutdown(self) -> None:
        # nothing to do since we don't maintain a persistent connection
@@ -501,41 +501,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
        await self.cache[vector_db_id].index.delete()
        del self.cache[vector_db_id]

-    # OpenAI Vector Store Mixin abstract method implementations
-    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Save vector store metadata to SQLite database."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-        self.openai_vector_stores[store_id] = store_info
-
-    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
-        """Load all vector store metadata from SQLite database."""
-        assert self.kvstore is not None
-        start_key = OPENAI_VECTOR_STORES_PREFIX
-        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
-        stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
-        stores = {}
-        for store_data in stored_openai_stores:
-            store_info = json.loads(store_data)
-            stores[store_info["id"]] = store_info
-        return stores
-
-    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Update vector store metadata in SQLite database."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-        self.openai_vector_stores[store_id] = store_info
-
-    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
-        """Delete vector store metadata from SQLite database."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.delete(key)
-        if store_id in self.openai_vector_stores:
-            del self.openai_vector_stores[store_id]
-
    async def _save_openai_vector_store_file(
        self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
    ) -> None:
@@ -12,6 +12,19 @@ from llama_stack.providers.utils.inference.model_registry import (
    build_model_entry,
)

+SAFETY_MODELS_ENTRIES = [
+    # The Llama Guard models don't have their full fp16 versions
+    # so we are going to alias their default version to the canonical SKU
+    build_hf_repo_model_entry(
+        "llama-guard3:8b",
+        CoreModelId.llama_guard_3_8b.value,
+    ),
+    build_hf_repo_model_entry(
+        "llama-guard3:1b",
+        CoreModelId.llama_guard_3_1b.value,
+    ),
+]
+
MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "llama3.1:8b-instruct-fp16",
@@ -73,16 +86,6 @@ MODEL_ENTRIES = [
        "llama3.3:70b",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
-    # The Llama Guard models don't have their full fp16 versions
-    # so we are going to alias their default version to the canonical SKU
-    build_hf_repo_model_entry(
-        "llama-guard3:8b",
-        CoreModelId.llama_guard_3_8b.value,
-    ),
-    build_hf_repo_model_entry(
-        "llama-guard3:1b",
-        CoreModelId.llama_guard_3_1b.value,
-    ),
    ProviderModelEntry(
        provider_model_id="all-minilm:l6-v2",
        aliases=["all-minilm"],
@@ -100,4 +103,4 @@ MODEL_ENTRIES = [
            "context_length": 8192,
        },
    ),
-]
+] + SAFETY_MODELS_ENTRIES
@@ -8,7 +8,7 @@ from typing import Any

from pydantic import BaseModel, ConfigDict, Field

-from llama_stack.providers.utils.kvstore.config import KVStoreConfig
+from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.schema_utils import json_schema_type
@@ -17,7 +17,7 @@ class MilvusVectorIOConfig(BaseModel):
    uri: str = Field(description="The URI of the Milvus server")
    token: str | None = Field(description="The token of the Milvus server")
    consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
-    kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
+    kvstore: KVStoreConfig = Field(description="Config for KV store backend")

    # This configuration allows additional fields to be passed through to the underlying Milvus client.
    # See the [Milvus](https://milvus.io/docs/install-overview.md) documentation for more details about Milvus in general.
@@ -25,4 +25,11 @@ class MilvusVectorIOConfig(BaseModel):

    @classmethod
    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
-        return {"uri": "${env.MILVUS_ENDPOINT}", "token": "${env.MILVUS_TOKEN}"}
+        return {
+            "uri": "${env.MILVUS_ENDPOINT}",
+            "token": "${env.MILVUS_TOKEN}",
+            "kvstore": SqliteKVStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="milvus_remote_registry.db",
+            ),
+        }
@@ -12,7 +12,7 @@ import re
from typing import Any

from numpy.typing import NDArray
-from pymilvus import DataType, MilvusClient
+from pymilvus import DataType, Function, FunctionType, MilvusClient

from llama_stack.apis.files.files import Files
from llama_stack.apis.inference import Inference, InterleavedContent
@@ -74,12 +74,66 @@ class MilvusIndex(EmbeddingIndex):
        assert len(chunks) == len(embeddings), (
            f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
        )

        if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
+            logger.info(f"Creating new collection {self.collection_name} with nullable sparse field")
+            # Create schema for vector search
+            schema = self.client.create_schema()
+            schema.add_field(
+                field_name="chunk_id",
+                datatype=DataType.VARCHAR,
+                is_primary=True,
+                max_length=100,
+            )
+            schema.add_field(
+                field_name="content",
+                datatype=DataType.VARCHAR,
+                max_length=65535,
+                enable_analyzer=True,  # Enable text analysis for BM25
+            )
+            schema.add_field(
+                field_name="vector",
+                datatype=DataType.FLOAT_VECTOR,
+                dim=len(embeddings[0]),
+            )
+            schema.add_field(
+                field_name="chunk_content",
+                datatype=DataType.JSON,
+            )
+            # Add sparse vector field for BM25 (required by the function)
+            schema.add_field(
+                field_name="sparse",
+                datatype=DataType.SPARSE_FLOAT_VECTOR,
+            )
+
+            # Create indexes
+            index_params = self.client.prepare_index_params()
+            index_params.add_index(
+                field_name="vector",
+                index_type="FLAT",
+                metric_type="COSINE",
+            )
+            # Add index for sparse field (required by BM25 function)
+            index_params.add_index(
+                field_name="sparse",
+                index_type="SPARSE_INVERTED_INDEX",
+                metric_type="BM25",
+            )
+
+            # Add BM25 function for full-text search
+            bm25_function = Function(
+                name="text_bm25_emb",
+                input_field_names=["content"],
+                output_field_names=["sparse"],
+                function_type=FunctionType.BM25,
+            )
+            schema.add_function(bm25_function)
+
            await asyncio.to_thread(
                self.client.create_collection,
                self.collection_name,
-                dimension=len(embeddings[0]),
-                auto_id=True,
+                schema=schema,
+                index_params=index_params,
                consistency_level=self.consistency_level,
            )
@@ -88,8 +142,10 @@ class MilvusIndex(EmbeddingIndex):
            data.append(
                {
                    "chunk_id": chunk.chunk_id,
+                    "content": chunk.content,
                    "vector": embedding,
                    "chunk_content": chunk.model_dump(),
+                    # sparse field will be handled by BM25 function automatically
                }
            )
        try:
@@ -107,6 +163,7 @@ class MilvusIndex(EmbeddingIndex):
                self.client.search,
                collection_name=self.collection_name,
                data=[embedding],
+                anns_field="vector",
                limit=k,
                output_fields=["*"],
                search_params={"params": {"radius": score_threshold}},
@@ -121,7 +178,64 @@ class MilvusIndex(EmbeddingIndex):
        k: int,
        score_threshold: float,
    ) -> QueryChunksResponse:
-        raise NotImplementedError("Keyword search is not supported in Milvus")
+        """
+        Perform BM25-based keyword search using Milvus's built-in full-text search.
+        """
+        try:
+            # Use Milvus's built-in BM25 search
+            search_res = await asyncio.to_thread(
+                self.client.search,
+                collection_name=self.collection_name,
+                data=[query_string],  # Raw text query
+                anns_field="sparse",  # Use sparse field for BM25
+                output_fields=["chunk_content"],  # Output the chunk content
+                limit=k,
+                search_params={
+                    "params": {
+                        "drop_ratio_search": 0.2,  # Ignore low-importance terms
+                    }
+                },
+            )
+
+            chunks = []
+            scores = []
+            for res in search_res[0]:
+                chunk = Chunk(**res["entity"]["chunk_content"])
+                chunks.append(chunk)
+                scores.append(res["distance"])  # BM25 score from Milvus
+
+            # Filter by score threshold
+            filtered_chunks = [chunk for chunk, score in zip(chunks, scores, strict=False) if score >= score_threshold]
+            filtered_scores = [score for score in scores if score >= score_threshold]
+
+            return QueryChunksResponse(chunks=filtered_chunks, scores=filtered_scores)
+
+        except Exception as e:
+            logger.error(f"Error performing BM25 search: {e}")
+            # Fallback to simple text search
+            return await self._fallback_keyword_search(query_string, k, score_threshold)
+
+    async def _fallback_keyword_search(
+        self,
+        query_string: str,
+        k: int,
+        score_threshold: float,
+    ) -> QueryChunksResponse:
+        """
+        Fallback to simple text search when BM25 search is not available.
+        """
+        # Simple text search using content field
+        search_res = await asyncio.to_thread(
+            self.client.query,
+            collection_name=self.collection_name,
+            filter='content like "%{content}%"',
+            filter_params={"content": query_string},
+            output_fields=["*"],
+            limit=k,
+        )
+        chunks = [Chunk(**res["chunk_content"]) for res in search_res]
+        scores = [1.0] * len(chunks)  # Simple binary score for text search
+        return QueryChunksResponse(chunks=chunks, scores=scores)

    async def query_hybrid(
        self,
@@ -179,7 +293,8 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
            uri = os.path.expanduser(self.config.db_path)
            self.client = MilvusClient(uri=uri)

-        self.openai_vector_stores = await self._load_openai_vector_stores()
+        # Load existing OpenAI vector stores into the in-memory cache
+        await self.initialize_openai_vector_stores()

    async def shutdown(self) -> None:
        self.client.close()
@@ -246,38 +361,16 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
        if not index:
            raise ValueError(f"Vector DB {vector_db_id} not found")

+        if params and params.get("mode") == "keyword":
+            # Check if this is inline Milvus (Milvus-Lite)
+            if hasattr(self.config, "db_path"):
+                raise NotImplementedError(
+                    "Keyword search is not supported in Milvus-Lite. "
+                    "Please use a remote Milvus server for keyword search functionality."
+                )
+
        return await index.query_chunks(query, params)

-    async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Save vector store metadata to persistent storage."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-        self.openai_vector_stores[store_id] = store_info
-
-    async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
-        """Update vector store metadata in persistent storage."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.set(key=key, value=json.dumps(store_info))
-        self.openai_vector_stores[store_id] = store_info
-
-    async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
-        """Delete vector store metadata from persistent storage."""
-        assert self.kvstore is not None
-        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
-        await self.kvstore.delete(key)
-        if store_id in self.openai_vector_stores:
-            del self.openai_vector_stores[store_id]
-
-    async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
-        """Load all vector store metadata from persistent storage."""
-        assert self.kvstore is not None
-        start_key = OPENAI_VECTOR_STORES_PREFIX
-        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
-        stored = await self.kvstore.values_in_range(start_key, end_key)
-        return {json.loads(s)["id"]: json.loads(s) for s in stored}
-
    async def _save_openai_vector_store_file(
        self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
    ) -> None:
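As a usage note (not part of the diff): with a remote Milvus server, the new keyword mode is selected through the `params` argument of `query_chunks`. The sketch below is hedged: the vector DB id, the query text, and the `params` keys other than `"mode"` are assumptions for illustration, not taken from this commit.

```python
# Minimal sketch of exercising the new keyword mode; "my-docs" and the extra
# params keys are hypothetical. With Milvus-Lite (db_path configured), the
# adapter change above raises NotImplementedError instead.
async def keyword_search_example(vector_io) -> None:
    response = await vector_io.query_chunks(
        vector_db_id="my-docs",                # hypothetical registered vector DB
        query="how do I configure TLS?",       # raw text, scored by BM25 against `content`
        params={"mode": "keyword", "max_chunks": 5, "score_threshold": 0.0},
    )
    for chunk, score in zip(response.chunks, response.scores, strict=False):
        print(score, chunk.chunk_id)
```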
@@ -8,6 +8,10 @@ from typing import Any

from pydantic import BaseModel, Field

+from llama_stack.providers.utils.kvstore.config import (
+    KVStoreConfig,
+    SqliteKVStoreConfig,
+)
from llama_stack.schema_utils import json_schema_type
@@ -18,10 +22,12 @@ class PGVectorVectorIOConfig(BaseModel):
    db: str | None = Field(default="postgres")
    user: str | None = Field(default="postgres")
    password: str | None = Field(default="mysecretpassword")
+    kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)

    @classmethod
    def sample_run_config(
        cls,
+        __distro_dir__: str,
        host: str = "${env.PGVECTOR_HOST:=localhost}",
        port: int = "${env.PGVECTOR_PORT:=5432}",
        db: str = "${env.PGVECTOR_DB}",
@@ -29,4 +35,14 @@ class PGVectorVectorIOConfig(BaseModel):
        password: str = "${env.PGVECTOR_PASSWORD}",
        **kwargs: Any,
    ) -> dict[str, Any]:
-        return {"host": host, "port": port, "db": db, "user": user, "password": password}
+        return {
+            "host": host,
+            "port": port,
+            "db": db,
+            "user": user,
+            "password": password,
+            "kvstore": SqliteKVStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="pgvector_registry.db",
+            ),
+        }
@@ -13,24 +13,18 @@ from psycopg2 import sql
from psycopg2.extras import Json, execute_values
from pydantic import BaseModel, TypeAdapter

+from llama_stack.apis.files.files import Files
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import (
    Chunk,
    QueryChunksResponse,
-    SearchRankingOptions,
    VectorIO,
-    VectorStoreChunkingStrategy,
-    VectorStoreDeleteResponse,
-    VectorStoreFileContentsResponse,
-    VectorStoreFileObject,
-    VectorStoreFileStatus,
-    VectorStoreListFilesResponse,
-    VectorStoreListResponse,
-    VectorStoreObject,
-    VectorStoreSearchResponsePage,
)
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import (
    EmbeddingIndex,
    VectorDBWithIndex,
@@ -40,6 +34,13 @@ from .config import PGVectorVectorIOConfig

log = logging.getLogger(__name__)

+VERSION = "v3"
+VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
+VECTOR_INDEX_PREFIX = f"vector_index:pgvector:{VERSION}::"
+OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:pgvector:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:pgvector:{VERSION}::"
+

def check_extension_version(cur):
    cur.execute("SELECT extversion FROM pg_extension WHERE extname = 'vector'")
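For orientation (not part of the diff), here is how these prefixes compose into concrete keys. The identifiers below are made up; the range-scan convention follows the `values_in_range(start_key, end_key)` calls used elsewhere in this commit.

```python
# Illustrative key layout for the pgvector registry at VERSION = "v3".
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:pgvector:{VERSION}::"

vector_db_id = "my-notes"   # hypothetical vector DB identifier
store_id = "vs_abc123"      # hypothetical OpenAI vector store id

print(f"{VECTOR_DBS_PREFIX}{vector_db_id}")        # vector_dbs:pgvector:v3::my-notes
print(f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}")  # openai_vector_stores:pgvector:v3::vs_abc123

# Listing everything under a prefix scans the half-open range [PREFIX, PREFIX + "\xff").
start_key = OPENAI_VECTOR_STORES_PREFIX
end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
```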
@@ -69,7 +70,7 @@ def load_models(cur, cls):


class PGVectorIndex(EmbeddingIndex):
-    def __init__(self, vector_db: VectorDB, dimension: int, conn):
+    def __init__(self, vector_db: VectorDB, dimension: int, conn, kvstore: KVStore | None = None):
        self.conn = conn
        with conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
            # Sanitize the table name by replacing hyphens with underscores
@@ -77,6 +78,7 @@ class PGVectorIndex(EmbeddingIndex):
            # when created with patterns like "test-vector-db-{uuid4()}"
            sanitized_identifier = vector_db.identifier.replace("-", "_")
            self.table_name = f"vector_store_{sanitized_identifier}"
+            self.kvstore = kvstore

            cur.execute(
                f"""
@@ -158,15 +160,28 @@ class PGVectorIndex(EmbeddingIndex):
            cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")


-class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
-    def __init__(self, config: PGVectorVectorIOConfig, inference_api: Api.inference) -> None:
+class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
+    def __init__(
+        self,
+        config: PGVectorVectorIOConfig,
+        inference_api: Api.inference,
+        files_api: Files | None = None,
+    ) -> None:
        self.config = config
        self.inference_api = inference_api
        self.conn = None
        self.cache = {}
+        self.files_api = files_api
+        self.kvstore: KVStore | None = None
+        self.vector_db_store = None
+        self.openai_vector_store: dict[str, dict[str, Any]] = {}
+        self.metadatadata_collection_name = "openai_vector_stores_metadata"

    async def initialize(self) -> None:
        log.info(f"Initializing PGVector memory adapter with config: {self.config}")
+        self.kvstore = await kvstore_impl(self.config.kvstore)
+        await self.initialize_openai_vector_stores()
+
        try:
            self.conn = psycopg2.connect(
                host=self.config.host,
@@ -201,15 +216,32 @@ class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
        log.info("Connection to PGVector database server closed")

    async def register_vector_db(self, vector_db: VectorDB) -> None:
+        # Persist vector DB metadata in the KV store
+        assert self.kvstore is not None
+        key = f"{VECTOR_DBS_PREFIX}{vector_db.identifier}"
+        await self.kvstore.set(key=key, value=vector_db.model_dump_json())
+
+        # Upsert model metadata in Postgres
        upsert_models(self.conn, [(vector_db.identifier, vector_db)])

-        index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
-        self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
+        # Create and cache the PGVector index table for the vector DB
+        index = VectorDBWithIndex(
+            vector_db,
+            index=PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn, kvstore=self.kvstore),
+            inference_api=self.inference_api,
+        )
+        self.cache[vector_db.identifier] = index

    async def unregister_vector_db(self, vector_db_id: str) -> None:
+        # Remove provider index and cache
+        if vector_db_id in self.cache:
            await self.cache[vector_db_id].index.delete()
            del self.cache[vector_db_id]

+        # Delete vector DB metadata from KV store
+        assert self.kvstore is not None
+        await self.kvstore.delete(key=f"{VECTOR_DBS_PREFIX}{vector_db_id}")
+
    async def insert_chunks(
        self,
        vector_db_id: str,
@@ -237,107 +269,20 @@ class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
        self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
        return self.cache[vector_db_id]

-    async def openai_create_vector_store(
-        self,
-        name: str,
-        file_ids: list[str] | None = None,
-        expires_after: dict[str, Any] | None = None,
-        chunking_strategy: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-        embedding_model: str | None = None,
-        embedding_dimension: int | None = 384,
-        provider_id: str | None = None,
-        provider_vector_db_id: str | None = None,
-    ) -> VectorStoreObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_list_vector_stores(
-        self,
-        limit: int | None = 20,
-        order: str | None = "desc",
-        after: str | None = None,
-        before: str | None = None,
-    ) -> VectorStoreListResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_retrieve_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_update_vector_store(
-        self,
-        vector_store_id: str,
-        name: str | None = None,
-        expires_after: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-    ) -> VectorStoreObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_delete_vector_store(
-        self,
-        vector_store_id: str,
-    ) -> VectorStoreDeleteResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_search_vector_store(
-        self,
-        vector_store_id: str,
-        query: str | list[str],
-        filters: dict[str, Any] | None = None,
-        max_num_results: int | None = 10,
-        ranking_options: SearchRankingOptions | None = None,
-        rewrite_query: bool | None = False,
-        search_mode: str | None = "vector",
-    ) -> VectorStoreSearchResponsePage:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_attach_file_to_vector_store(
-        self,
-        vector_store_id: str,
-        file_id: str,
-        attributes: dict[str, Any] | None = None,
-        chunking_strategy: VectorStoreChunkingStrategy | None = None,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_list_files_in_vector_store(
-        self,
-        vector_store_id: str,
-        limit: int | None = 20,
-        order: str | None = "desc",
-        after: str | None = None,
-        before: str | None = None,
-        filter: VectorStoreFileStatus | None = None,
-    ) -> VectorStoreListFilesResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_retrieve_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_retrieve_vector_store_file_contents(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileContentsResponse:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_update_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-        attributes: dict[str, Any] | None = None,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
-
-    async def openai_delete_vector_store_file(
-        self,
-        vector_store_id: str,
-        file_id: str,
-    ) -> VectorStoreFileObject:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
+    # OpenAI Vector Stores File operations are not supported in PGVector
+    async def _save_openai_vector_store_file(
+        self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
+    ) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
+
+    async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
+
+    async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
+
+    async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
+
+    async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in PGVector")
@@ -6,15 +6,26 @@

from typing import Any

-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+
+from llama_stack.providers.utils.kvstore.config import (
+    KVStoreConfig,
+    SqliteKVStoreConfig,
+)


class WeaviateRequestProviderData(BaseModel):
    weaviate_api_key: str
    weaviate_cluster_url: str
+    kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)


class WeaviateVectorIOConfig(BaseModel):
    @classmethod
-    def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]:
-        return {}
+    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
+        return {
+            "kvstore": SqliteKVStoreConfig.sample_run_config(
+                __distro_dir__=__distro_dir__,
+                db_name="weaviate_registry.db",
+            ),
+        }
@@ -14,10 +14,13 @@ from weaviate.classes.init import Auth
from weaviate.classes.query import Filter

from llama_stack.apis.common.content_types import InterleavedContent
+from llama_stack.apis.files.files import Files
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
+from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import (
    EmbeddingIndex,
    VectorDBWithIndex,
@@ -27,11 +30,19 @@ from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig

log = logging.getLogger(__name__)

+VERSION = "v3"
+VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"
+VECTOR_INDEX_PREFIX = f"vector_index:weaviate:{VERSION}::"
+OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:weaviate:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:weaviate:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:weaviate:{VERSION}::"
+

class WeaviateIndex(EmbeddingIndex):
-    def __init__(self, client: weaviate.Client, collection_name: str):
+    def __init__(self, client: weaviate.Client, collection_name: str, kvstore: KVStore | None = None):
        self.client = client
        self.collection_name = collection_name
+        self.kvstore = kvstore

    async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
        assert len(chunks) == len(embeddings), (
@@ -109,11 +120,21 @@ class WeaviateVectorIOAdapter(
    NeedsRequestProviderData,
    VectorDBsProtocolPrivate,
):
-    def __init__(self, config: WeaviateVectorIOConfig, inference_api: Api.inference) -> None:
+    def __init__(
+        self,
+        config: WeaviateVectorIOConfig,
+        inference_api: Api.inference,
+        files_api: Files | None,
+    ) -> None:
        self.config = config
        self.inference_api = inference_api
        self.client_cache = {}
        self.cache = {}
+        self.files_api = files_api
+        self.kvstore: KVStore | None = None
+        self.vector_db_store = None
+        self.openai_vector_stores: dict[str, dict[str, Any]] = {}
+        self.metadata_collection_name = "openai_vector_stores_metadata"

    def _get_client(self) -> weaviate.Client:
        provider_data = self.get_request_provider_data()
@@ -132,7 +153,26 @@ class WeaviateVectorIOAdapter(
        return client

    async def initialize(self) -> None:
-        pass
+        """Set up KV store and load existing vector DBs and OpenAI vector stores."""
+        # Initialize KV store for metadata
+        self.kvstore = await kvstore_impl(self.config.kvstore)
+
+        # Load existing vector DB definitions
+        start_key = VECTOR_DBS_PREFIX
+        end_key = f"{VECTOR_DBS_PREFIX}\xff"
+        stored = await self.kvstore.values_in_range(start_key, end_key)
+        for raw in stored:
+            vector_db = VectorDB.model_validate_json(raw)
+            client = self._get_client()
+            idx = WeaviateIndex(client=client, collection_name=vector_db.identifier, kvstore=self.kvstore)
+            self.cache[vector_db.identifier] = VectorDBWithIndex(
+                vector_db=vector_db,
+                index=idx,
+                inference_api=self.inference_api,
+            )
+
+        # Load OpenAI vector stores metadata into cache
+        await self.initialize_openai_vector_stores()

    async def shutdown(self) -> None:
        for client in self.client_cache.values():
@@ -206,3 +246,21 @@ class WeaviateVectorIOAdapter(
            raise ValueError(f"Vector DB {vector_db_id} not found")

        return await index.query_chunks(query, params)
+
+    # OpenAI Vector Stores File operations are not supported in Weaviate
+    async def _save_openai_vector_store_file(
+        self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
+    ) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+
+    async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+
+    async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+
+    async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+
+    async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
@@ -5,6 +5,7 @@
# the root directory of this source tree.

import asyncio
+import json
import logging
import mimetypes
import time
@@ -35,6 +36,7 @@ from llama_stack.apis.vector_io import (
    VectorStoreSearchResponse,
    VectorStoreSearchResponsePage,
)
+from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, make_overlapped_chunks

logger = logging.getLogger(__name__)
@@ -59,26 +61,45 @@ class OpenAIVectorStoreMixin(ABC):
     # These should be provided by the implementing class
     openai_vector_stores: dict[str, dict[str, Any]]
     files_api: Files | None
+    # KV store for persisting OpenAI vector store metadata
+    kvstore: KVStore | None

-    @abstractmethod
     async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
         """Save vector store metadata to persistent storage."""
-        pass
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.set(key=key, value=json.dumps(store_info))
+        # update in-memory cache
+        self.openai_vector_stores[store_id] = store_info

-    @abstractmethod
     async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
         """Load all vector store metadata from persistent storage."""
-        pass
+        assert self.kvstore is not None
+        start_key = OPENAI_VECTOR_STORES_PREFIX
+        end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
+        stored_data = await self.kvstore.values_in_range(start_key, end_key)
+
+        stores: dict[str, dict[str, Any]] = {}
+        for item in stored_data:
+            info = json.loads(item)
+            stores[info["id"]] = info
+        return stores

-    @abstractmethod
     async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
         """Update vector store metadata in persistent storage."""
-        pass
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.set(key=key, value=json.dumps(store_info))
+        # update in-memory cache
+        self.openai_vector_stores[store_id] = store_info

-    @abstractmethod
     async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
         """Delete vector store metadata from persistent storage."""
-        pass
+        assert self.kvstore is not None
+        key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+        await self.kvstore.delete(key)
+        # remove from in-memory cache
+        self.openai_vector_stores.pop(store_id, None)

     @abstractmethod
     async def _save_openai_vector_store_file(
@@ -117,6 +138,10 @@ class OpenAIVectorStoreMixin(ABC):
         """Unregister a vector database (provider-specific implementation)."""
         pass

+    async def initialize_openai_vector_stores(self) -> None:
+        """Load existing OpenAI vector stores into the in-memory cache."""
+        self.openai_vector_stores = await self._load_openai_vector_stores()
+
     @abstractmethod
     async def insert_chunks(
         self,
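Illustrative aside (not part of the diff): a minimal sketch of how a provider adapter could wire up the KV-backed persistence the mixin now provides, mirroring the Weaviate changes above. The class name, constructor arguments, and the import paths for kvstore_impl and the mixin are assumptions made only for this example.

from typing import Any

# Assumed import paths -- the diff does not show where these modules live.
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin


class ExampleVectorIOAdapter(OpenAIVectorStoreMixin):
    """Hypothetical adapter. The file-operation methods still required by the mixin
    (_save_openai_vector_store_file, etc.) would need concrete overrides, as the
    Weaviate adapter above does with NotImplementedError stubs."""

    def __init__(self, config, inference_api, files_api=None) -> None:
        self.config = config
        self.inference_api = inference_api
        self.files_api = files_api
        self.kvstore = None
        self.openai_vector_stores: dict[str, dict[str, Any]] = {}

    async def initialize(self) -> None:
        # Same pattern as WeaviateVectorIOAdapter.initialize() above:
        # create the KV store, then hydrate the in-memory cache from it.
        self.kvstore = await kvstore_impl(self.config.kvstore)
        await self.initialize_openai_vector_stores()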
@@ -68,7 +68,7 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]

-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="nvidia",
         distro_type="self_hosted",
@@ -128,6 +128,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_id="${env.ENABLE_PGVECTOR:+pgvector}",
             provider_type="remote::pgvector",
             config=PGVectorVectorIOConfig.sample_run_config(
+                f"~/.llama/distributions/{name}",
                 db="${env.PGVECTOR_DB:=}",
                 user="${env.PGVECTOR_USER:=}",
                 password="${env.PGVECTOR_PASSWORD:=}",
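Illustrative aside (not part of the diff): the new positional argument to PGVectorVectorIOConfig.sample_run_config is the distribution's state directory; judging from the run.yaml hunks further down, the rendered pgvector config then carries a sqlite-backed registry under that directory. A sketch of the call, with the exact return shape left as an assumption:

from llama_stack.providers.remote.vector_io.pgvector.config import PGVectorVectorIOConfig

# Mirrors the template call above; "open-benchmark" stands in for the {name} variable.
sample = PGVectorVectorIOConfig.sample_run_config(
    "~/.llama/distributions/open-benchmark",
    db="${env.PGVECTOR_DB:=}",
    user="${env.PGVECTOR_USER:=}",
    password="${env.PGVECTOR_PASSWORD:=}",
)
# Expected (per the run.yaml hunks below) to include something along the lines of:
#   kvstore:
#     type: sqlite
#     db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/pgvector_registry.db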
@@ -146,7 +147,8 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]

-    default_models = get_model_registry(available_models) + [
+    models, _ = get_model_registry(available_models)
+    default_models = models + [
         ModelInput(
             model_id="meta-llama/Llama-3.3-70B-Instruct",
             provider_id="groq",
@@ -54,6 +54,9 @@ providers:
       db: ${env.PGVECTOR_DB:=}
       user: ${env.PGVECTOR_USER:=}
       password: ${env.PGVECTOR_PASSWORD:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/pgvector_registry.db
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard
@@ -166,6 +166,9 @@ providers:
       db: ${env.PGVECTOR_DB:=}
       user: ${env.PGVECTOR_USER:=}
       password: ${env.PGVECTOR_PASSWORD:=}
+      kvstore:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
   files:
   - provider_id: meta-reference-files
     provider_type: inline::localfs
@@ -1171,24 +1174,8 @@ models:
   provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}
   model_type: embedding
 shields:
-- shield_id: ${env.ENABLE_OLLAMA:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=llama-guard3:1b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-8b}
-- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-11b-vision}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-8B}
-- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-11B-Vision-Turbo}
-- shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
-  provider_id: llama-guard
-  provider_shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/${env.SAFETY_MODEL:=sambanova/Meta-Llama-Guard-3-8B}
+- shield_id: ${env.SAFETY_MODEL:=__disabled__}
+  provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}
 vector_dbs: []
 datasets: []
 scoring_fns: []
@@ -12,7 +12,6 @@ from llama_stack.distribution.datatypes import (
     ModelInput,
     Provider,
     ProviderSpec,
-    ShieldInput,
     ToolGroupInput,
 )
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -32,75 +31,39 @@ from llama_stack.providers.registry.inference import available_providers
 from llama_stack.providers.remote.inference.anthropic.models import (
     MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.anthropic.models import (
-    SAFETY_MODELS_ENTRIES as ANTHROPIC_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.bedrock.models import (
     MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.bedrock.models import (
-    SAFETY_MODELS_ENTRIES as BEDROCK_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.cerebras.models import (
     MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.cerebras.models import (
-    SAFETY_MODELS_ENTRIES as CEREBRAS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.databricks.databricks import (
     MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.databricks.databricks import (
-    SAFETY_MODELS_ENTRIES as DATABRICKS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.fireworks.models import (
     MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.fireworks.models import (
-    SAFETY_MODELS_ENTRIES as FIREWORKS_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.gemini.models import (
     MODEL_ENTRIES as GEMINI_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.gemini.models import (
-    SAFETY_MODELS_ENTRIES as GEMINI_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.groq.models import (
     MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.groq.models import (
-    SAFETY_MODELS_ENTRIES as GROQ_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.nvidia.models import (
     MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.nvidia.models import (
-    SAFETY_MODELS_ENTRIES as NVIDIA_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.openai.models import (
     MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.openai.models import (
-    SAFETY_MODELS_ENTRIES as OPENAI_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.runpod.runpod import (
     MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.runpod.runpod import (
-    SAFETY_MODELS_ENTRIES as RUNPOD_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.sambanova.models import (
     MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.sambanova.models import (
-    SAFETY_MODELS_ENTRIES as SAMBANOVA_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.inference.together.models import (
     MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
 )
-from llama_stack.providers.remote.inference.together.models import (
-    SAFETY_MODELS_ENTRIES as TOGETHER_SAFETY_MODELS_ENTRIES,
-)
 from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
 from llama_stack.providers.remote.vector_io.pgvector.config import (
     PGVectorVectorIOConfig,
@@ -111,6 +74,7 @@ from llama_stack.templates.template import (
     DistributionTemplate,
     RunConfigSettings,
     get_model_registry,
+    get_shield_registry,
 )

@@ -164,28 +128,13 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
 def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
     """Get model entries for a specific provider type."""
     safety_model_entries_map = {
-        "openai": OPENAI_SAFETY_MODELS_ENTRIES,
-        "fireworks": FIREWORKS_SAFETY_MODELS_ENTRIES,
-        "together": TOGETHER_SAFETY_MODELS_ENTRIES,
-        "anthropic": ANTHROPIC_SAFETY_MODELS_ENTRIES,
-        "gemini": GEMINI_SAFETY_MODELS_ENTRIES,
-        "groq": GROQ_SAFETY_MODELS_ENTRIES,
-        "sambanova": SAMBANOVA_SAFETY_MODELS_ENTRIES,
-        "cerebras": CEREBRAS_SAFETY_MODELS_ENTRIES,
-        "bedrock": BEDROCK_SAFETY_MODELS_ENTRIES,
-        "databricks": DATABRICKS_SAFETY_MODELS_ENTRIES,
-        "nvidia": NVIDIA_SAFETY_MODELS_ENTRIES,
-        "runpod": RUNPOD_SAFETY_MODELS_ENTRIES,
-    }
-
-    # Special handling for providers with dynamic model entries
-    if provider_type == "ollama":
-        return [
+        "ollama": [
             ProviderModelEntry(
-                provider_model_id="llama-guard3:1b",
+                provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
                 model_type=ModelType.llm,
             ),
-        ]
+        ],
+    }

     return safety_model_entries_map.get(provider_type, [])
@@ -246,28 +195,20 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[ProviderModelEntry]]]:

 # build a list of shields for all possible providers
-def get_shields_for_providers(providers: list[Provider]) -> list[ShieldInput]:
-    shields = []
+def get_safety_models_for_providers(providers: list[Provider]) -> dict[str, list[ProviderModelEntry]]:
+    available_models = {}
     for provider in providers:
         provider_type = provider.provider_type.split("::")[1]
         safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
         if len(safety_model_entries) == 0:
             continue
-        if provider.provider_id:
-            shield_id = provider.provider_id
-        else:
-            raise ValueError(f"Provider {provider.provider_type} has no provider_id")
-        for safety_model_entry in safety_model_entries:
-            print(f"provider.provider_id: {provider.provider_id}")
-            print(f"safety_model_entry.provider_model_id: {safety_model_entry.provider_model_id}")
-            shields.append(
-                ShieldInput(
-                    provider_id="llama-guard",
-                    shield_id=shield_id,
-                    provider_shield_id=f"{provider.provider_id}/${{env.SAFETY_MODEL:={safety_model_entry.provider_model_id}}}",
-                )
-            )
-    return shields
+
+        env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
+        provider_id = f"${{env.{env_var}:=__disabled__}}"
+        available_models[provider_id] = safety_model_entries
+    return available_models


 def get_distribution_template() -> DistributionTemplate:
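Illustrative aside (not part of the diff): what the env-var and provider-id construction above produces for a couple of provider types. The values follow directly from the string manipulation in get_safety_models_for_providers and match the ${env.ENABLE_*} identifiers seen in the run.yaml hunks; the second provider type below is hypothetical and included only to show the '-' to '_' replacement.

# Plain-Python check of the id construction used above (no llama-stack imports needed).
for provider_type, expected in [
    ("ollama", "${env.ENABLE_OLLAMA:=__disabled__}"),
    ("example-provider", "${env.ENABLE_EXAMPLE_PROVIDER:=__disabled__}"),  # hypothetical
]:
    env_var = f"ENABLE_{provider_type.upper().replace('-', '_').replace('::', '_')}"
    provider_id = f"${{env.{env_var}:=__disabled__}}"
    assert provider_id == expected, (provider_id, expected)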
@@ -300,6 +241,7 @@ def get_distribution_template() -> DistributionTemplate:
             provider_id="${env.ENABLE_PGVECTOR:=__disabled__}",
             provider_type="remote::pgvector",
             config=PGVectorVectorIOConfig.sample_run_config(
+                f"~/.llama/distributions/{name}",
                 db="${env.PGVECTOR_DB:=}",
                 user="${env.PGVECTOR_USER:=}",
                 password="${env.PGVECTOR_PASSWORD:=}",
@@ -307,8 +249,6 @@ def get_distribution_template() -> DistributionTemplate:
         ),
     ]

-    shields = get_shields_for_providers(remote_inference_providers)
-
     providers = {
         "inference": ([p.provider_type for p in remote_inference_providers] + ["inline::sentence-transformers"]),
         "vector_io": ([p.provider_type for p in vector_io_providers]),
@@ -361,7 +301,10 @@
         },
     )

-    default_models = get_model_registry(available_models)
+    default_models, ids_conflict_in_models = get_model_registry(available_models)
+
+    available_safety_models = get_safety_models_for_providers(remote_inference_providers)
+    shields = get_shield_registry(available_safety_models, ids_conflict_in_models)

     return DistributionTemplate(
         name=name,
@@ -37,7 +37,7 @@ from llama_stack.providers.utils.sqlstore.sqlstore import get_pip_packages as ge

 def get_model_registry(
     available_models: dict[str, list[ProviderModelEntry]],
-) -> list[ModelInput]:
+) -> tuple[list[ModelInput], bool]:
     models = []

     # check for conflicts in model ids
@@ -74,7 +74,50 @@
                     metadata=entry.metadata,
                 )
             )
-    return models
+    return models, ids_conflict
+
+
+def get_shield_registry(
+    available_safety_models: dict[str, list[ProviderModelEntry]],
+    ids_conflict_in_models: bool,
+) -> list[ShieldInput]:
+    shields = []
+
+    # check for conflicts in shield ids
+    all_ids = set()
+    ids_conflict = False
+
+    for _, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                if model_id in all_ids:
+                    ids_conflict = True
+                    rich.print(
+                        f"[yellow]Shield id {model_id} conflicts; all shield ids will be prefixed with provider id[/yellow]"
+                    )
+                    break
+            all_ids.update(ids)
+            if ids_conflict:
+                break
+        if ids_conflict:
+            break
+
+    for provider_id, entries in available_safety_models.items():
+        for entry in entries:
+            ids = [entry.provider_model_id] + entry.aliases
+            for model_id in ids:
+                identifier = f"{provider_id}/{model_id}" if ids_conflict and provider_id not in model_id else model_id
+                shields.append(
+                    ShieldInput(
+                        shield_id=identifier,
+                        provider_shield_id=f"{provider_id}/{entry.provider_model_id}"
+                        if ids_conflict_in_models
+                        else entry.provider_model_id,
+                    )
+                )
+
+    return shields


 class DefaultModel(BaseModel):
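Illustrative aside (not part of the diff): the shield-id and provider-shield-id rules from get_shield_registry, replayed on hand-written strings. The starter run.yaml hunk earlier shows the prefixed provider_shield_id form, which per this logic corresponds to ids_conflict_in_models being true; the values below are chosen to reproduce that output and are not taken from real registry data.

provider_id = "${env.ENABLE_OLLAMA:=__disabled__}"
provider_model_id = "${env.SAFETY_MODEL:=__disabled__}"

# shield_id: only prefixed when the shield ids themselves collide across providers
ids_conflict = False
shield_id = (
    f"{provider_id}/{provider_model_id}"
    if ids_conflict and provider_id not in provider_model_id
    else provider_model_id
)
assert shield_id == "${env.SAFETY_MODEL:=__disabled__}"

# provider_shield_id: prefixed whenever the model registry reported conflicting ids
ids_conflict_in_models = True
provider_shield_id = (
    f"{provider_id}/{provider_model_id}" if ids_conflict_in_models else provider_model_id
)
assert provider_shield_id == "${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}"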
@@ -69,7 +69,7 @@ def get_distribution_template() -> DistributionTemplate:
         },
     )

-    default_models = get_model_registry(available_models)
+    default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
         name="watsonx",
         distro_type="remote_hosted",
@@ -77,6 +77,24 @@ def agent_config(llama_stack_client, text_model_id):
     return agent_config


+@pytest.fixture(scope="session")
+def agent_config_without_safety(text_model_id):
+    agent_config = dict(
+        model=text_model_id,
+        instructions="You are a helpful assistant",
+        sampling_params={
+            "strategy": {
+                "type": "top_p",
+                "temperature": 0.0001,
+                "top_p": 0.9,
+            },
+        },
+        tools=[],
+        enable_session_persistence=False,
+    )
+    return agent_config
+
+
 def test_agent_simple(llama_stack_client, agent_config):
     agent = Agent(llama_stack_client, **agent_config)
     session_id = agent.create_session(f"test-session-{uuid4()}")
@@ -491,7 +509,7 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
     assert expected_kw in response.output_message.content.lower()


-def test_rag_agent_with_attachments(llama_stack_client, agent_config):
+def test_rag_agent_with_attachments(llama_stack_client, agent_config_without_safety):
     urls = ["llama3.rst", "lora_finetune.rst"]
     documents = [
         # passign as url
@@ -514,14 +532,8 @@
             metadata={},
         ),
     ]
-    rag_agent = Agent(llama_stack_client, **agent_config)
+    rag_agent = Agent(llama_stack_client, **agent_config_without_safety)
     session_id = rag_agent.create_session(f"test-session-{uuid4()}")
-    user_prompts = [
-        (
-            "Instead of the standard multi-head attention, what attention type does Llama3-8B use?",
-            "grouped",
-        ),
-    ]
     user_prompts = [
         (
             "I am attaching some documentation for Torchtune. Help me answer questions I will ask next.",
@@ -549,82 +561,6 @@
     assert "lora" in response.output_message.content.lower()


-@pytest.mark.skip(reason="Code interpreter is currently disabled in the Stack")
-def test_rag_and_code_agent(llama_stack_client, agent_config):
-    if "llama-4" in agent_config["model"].lower():
-        pytest.xfail("Not working for llama4")
-
-    documents = []
-    documents.append(
-        Document(
-            document_id="nba_wiki",
-            content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).",
-            metadata={},
-        )
-    )
-    documents.append(
-        Document(
-            document_id="perplexity_wiki",
-            content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning:
-
-    Srinivas, the CEO, worked at OpenAI as an AI researcher.
-    Konwinski was among the founding team at Databricks.
-    Yarats, the CTO, was an AI research scientist at Meta.
-    Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""",
-            metadata={},
-        )
-    )
-    vector_db_id = f"test-vector-db-{uuid4()}"
-    llama_stack_client.vector_dbs.register(
-        vector_db_id=vector_db_id,
-        embedding_model="all-MiniLM-L6-v2",
-        embedding_dimension=384,
-    )
-    llama_stack_client.tool_runtime.rag_tool.insert(
-        documents=documents,
-        vector_db_id=vector_db_id,
-        chunk_size_in_tokens=128,
-    )
-    agent_config = {
-        **agent_config,
-        "tools": [
-            dict(
-                name="builtin::rag/knowledge_search",
-                args={"vector_db_ids": [vector_db_id]},
-            ),
-            "builtin::code_interpreter",
-        ],
-    }
-    agent = Agent(llama_stack_client, **agent_config)
-    user_prompts = [
-        (
-            "when was Perplexity the company founded?",
-            [],
-            "knowledge_search",
-            "2022",
-        ),
-        (
-            "when was the nba created?",
-            [],
-            "knowledge_search",
-            "1949",
-        ),
-    ]
-
-    for prompt, docs, tool_name, expected_kw in user_prompts:
-        session_id = agent.create_session(f"test-session-{uuid4()}")
-        response = agent.create_turn(
-            messages=[{"role": "user", "content": prompt}],
-            session_id=session_id,
-            documents=docs,
-            stream=False,
-        )
-        tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution")
-        assert tool_execution_step.tool_calls[0].tool_name == tool_name, f"Failed on {prompt}"
-        if expected_kw:
-            assert expected_kw in response.output_message.content.lower()
-
-
 @pytest.mark.parametrize(
     "client_tools",
     [(get_boiling_point, False), (get_boiling_point_with_metadata, True)],
||||||
191
tests/unit/providers/vector_io/remote/test_milvus.py
Normal file
191
tests/unit/providers/vector_io/remote/test_milvus.py
Normal file
|
|
@ -0,0 +1,191 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from llama_stack.apis.vector_io import QueryChunksResponse
|
||||||
|
|
||||||
|
# Mock the entire pymilvus module
|
||||||
|
pymilvus_mock = MagicMock()
|
||||||
|
pymilvus_mock.DataType = MagicMock()
|
||||||
|
pymilvus_mock.MilvusClient = MagicMock
|
||||||
|
|
||||||
|
# Apply the mock before importing MilvusIndex
|
||||||
|
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
|
||||||
|
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
|
||||||
|
|
||||||
|
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
|
||||||
|
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||||
|
# tests/integration/vector_io/
|
||||||
|
#
|
||||||
|
# How to run this test:
|
||||||
|
#
|
||||||
|
# pytest tests/unit/providers/vector_io/test_milvus.py \
|
||||||
|
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||||
|
|
||||||
|
MILVUS_PROVIDER = "milvus"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def mock_milvus_client() -> MagicMock:
|
||||||
|
"""Create a mock Milvus client with common method behaviors."""
|
||||||
|
client = MagicMock()
|
||||||
|
|
||||||
|
# Mock collection operations
|
||||||
|
client.has_collection.return_value = False # Initially no collection
|
||||||
|
client.create_collection.return_value = None
|
||||||
|
client.drop_collection.return_value = None
|
||||||
|
|
||||||
|
# Mock insert operation
|
||||||
|
client.insert.return_value = {"insert_count": 10}
|
||||||
|
|
||||||
|
# Mock search operation - return mock results (data should be dict, not JSON string)
|
||||||
|
client.search.return_value = [
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"id": 0,
|
||||||
|
"distance": 0.1,
|
||||||
|
"entity": {"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": 1,
|
||||||
|
"distance": 0.2,
|
||||||
|
"entity": {"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}}},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
]
|
||||||
|
|
||||||
|
# Mock query operation for keyword search (data should be dict, not JSON string)
|
||||||
|
client.query.return_value = [
|
||||||
|
{
|
||||||
|
"chunk_id": "chunk1",
|
||||||
|
"chunk_content": {"content": "mock chunk 1", "metadata": {"document_id": "doc1"}},
|
||||||
|
"score": 0.9,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"chunk_id": "chunk2",
|
||||||
|
"chunk_content": {"content": "mock chunk 2", "metadata": {"document_id": "doc2"}},
|
||||||
|
"score": 0.8,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"chunk_id": "chunk3",
|
||||||
|
"chunk_content": {"content": "mock chunk 3", "metadata": {"document_id": "doc3"}},
|
||||||
|
"score": 0.7,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
return client
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def milvus_index(mock_milvus_client):
|
||||||
|
"""Create a MilvusIndex with mocked client."""
|
||||||
|
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
|
||||||
|
yield index
|
||||||
|
# No real cleanup needed since we're using mocks
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||||
|
# Setup: collection doesn't exist initially, then exists after creation
|
||||||
|
mock_milvus_client.has_collection.side_effect = [False, True]
|
||||||
|
|
||||||
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Verify collection was created and data was inserted
|
||||||
|
mock_milvus_client.create_collection.assert_called_once()
|
||||||
|
mock_milvus_client.insert.assert_called_once()
|
||||||
|
|
||||||
|
# Verify the insert call had the right number of chunks
|
||||||
|
insert_call = mock_milvus_client.insert.call_args
|
||||||
|
assert len(insert_call[1]["data"]) == len(sample_chunks)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_vector(
|
||||||
|
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
|
||||||
|
):
|
||||||
|
# Setup: Add chunks first
|
||||||
|
mock_milvus_client.has_collection.return_value = True
|
||||||
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Test vector search
|
||||||
|
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
|
||||||
|
response = await milvus_index.query_vector(query_embedding, k=2, score_threshold=0.0)
|
||||||
|
|
||||||
|
assert isinstance(response, QueryChunksResponse)
|
||||||
|
assert len(response.chunks) == 2
|
||||||
|
mock_milvus_client.search.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||||
|
mock_milvus_client.has_collection.return_value = True
|
||||||
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Test keyword search
|
||||||
|
query_string = "Sentence 5"
|
||||||
|
response = await milvus_index.query_keyword(query_string=query_string, k=2, score_threshold=0.0)
|
||||||
|
|
||||||
|
assert isinstance(response, QueryChunksResponse)
|
||||||
|
assert len(response.chunks) == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
|
||||||
|
"""Test that when BM25 search fails, the system falls back to simple text search."""
|
||||||
|
mock_milvus_client.has_collection.return_value = True
|
||||||
|
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
|
||||||
|
|
||||||
|
# Force BM25 search to fail
|
||||||
|
mock_milvus_client.search.side_effect = Exception("BM25 search not available")
|
||||||
|
|
||||||
|
# Mock simple text search results
|
||||||
|
mock_milvus_client.query.return_value = [
|
||||||
|
{
|
||||||
|
"chunk_id": "chunk1",
|
||||||
|
"chunk_content": {"content": "Python programming language", "metadata": {"document_id": "doc1"}},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"chunk_id": "chunk2",
|
||||||
|
"chunk_content": {"content": "Machine learning algorithms", "metadata": {"document_id": "doc2"}},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Test keyword search that should fall back to simple text search
|
||||||
|
query_string = "Python"
|
||||||
|
response = await milvus_index.query_keyword(query_string=query_string, k=3, score_threshold=0.0)
|
||||||
|
|
||||||
|
# Verify response structure
|
||||||
|
assert isinstance(response, QueryChunksResponse)
|
||||||
|
assert len(response.chunks) > 0, "Fallback search should return results"
|
||||||
|
|
||||||
|
# Verify that simple text search was used (query method called instead of search)
|
||||||
|
mock_milvus_client.query.assert_called_once()
|
||||||
|
mock_milvus_client.search.assert_called_once() # Called once but failed
|
||||||
|
|
||||||
|
# Verify the query uses parameterized filter with filter_params
|
||||||
|
query_call_args = mock_milvus_client.query.call_args
|
||||||
|
assert "filter" in query_call_args[1], "Query should include filter for text search"
|
||||||
|
assert "filter_params" in query_call_args[1], "Query should use parameterized filter"
|
||||||
|
assert query_call_args[1]["filter_params"]["content"] == "Python", "Filter params should contain the search term"
|
||||||
|
|
||||||
|
# Verify all returned chunks have score 1.0 (simple binary scoring)
|
||||||
|
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_collection(milvus_index, mock_milvus_client):
|
||||||
|
# Test collection deletion
|
||||||
|
mock_milvus_client.has_collection.return_value = True
|
||||||
|
|
||||||
|
await milvus_index.delete()
|
||||||
|
|
||||||
|
mock_milvus_client.drop_collection.assert_called_once_with(collection_name=milvus_index.collection_name)
|
||||||
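Illustrative aside (not part of the diff): the new Milvus unit tests reference sample_chunks, sample_embeddings, and embedding_dimension fixtures that are not defined in this file and presumably come from a shared conftest.py elsewhere in the test tree. A hypothetical minimal version of those fixtures, with the Chunk import path and field names assumed, might look like:

import numpy as np
import pytest

from llama_stack.apis.vector_io import Chunk  # assumed import path


@pytest.fixture
def embedding_dimension() -> int:
    return 384


@pytest.fixture
def sample_chunks():
    # Content mirrors the "Sentence 5" keyword used in test_query_chunks_keyword_search.
    return [Chunk(content=f"Sentence {i}", metadata={"document_id": f"doc{i}"}) for i in range(10)]


@pytest.fixture
def sample_embeddings(sample_chunks, embedding_dimension):
    rng = np.random.default_rng(42)
    return rng.random((len(sample_chunks), embedding_dimension), dtype=np.float32)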