Merge branch 'main' into allow-dynamic-models-ollama

Matthew Farrellee 2025-07-14 17:29:18 -04:00
commit b6a334604c
113 changed files with 3795 additions and 3100 deletions

.github/CODEOWNERS

@@ -2,4 +2,4 @@
 # These owners will be the default owners for everything in
 # the repo. Unless a later match takes precedence,
-* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning @reluctantfuturist
+* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning @reluctantfuturist @mattf


@@ -7,3 +7,7 @@ runs:
 shell: bash
 run: |
 docker run -d --name ollama -p 11434:11434 docker.io/leseb/ollama-with-models
+# TODO: rebuild an ollama image with llama-guard3:1b
+echo "Verifying Ollama status..."
+timeout 30 bash -c 'while ! curl -s -L http://127.0.0.1:11434; do sleep 1 && echo "."; done'
+docker exec ollama ollama pull llama-guard3:1b


@@ -3,10 +3,10 @@ name: Installer CI
 on:
 pull_request:
 paths:
-- 'install.sh'
+- 'scripts/install.sh'
 push:
 paths:
-- 'install.sh'
+- 'scripts/install.sh'
 schedule:
 - cron: '0 2 * * *' # every day at 02:00 UTC
@@ -16,11 +16,11 @@ jobs:
 steps:
 - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
 - name: Run ShellCheck on install.sh
-run: shellcheck install.sh
+run: shellcheck scripts/install.sh
 smoke-test:
 needs: lint
 runs-on: ubuntu-latest
 steps:
 - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
 - name: Run installer end-to-end
-run: ./install.sh
+run: ./scripts/install.sh


@@ -35,7 +35,7 @@ jobs:
 - name: Install minikube
 if: ${{ matrix.auth-provider == 'kubernetes' }}
-uses: medyagh/setup-minikube@cea33675329b799adccc9526aa5daccc26cd5052 # v0.0.19
+uses: medyagh/setup-minikube@e3c7f79eb1e997eabccc536a6cf318a2b0fe19d9 # v0.0.20
 - name: Start minikube
 if: ${{ matrix.auth-provider == 'oauth2_token' }}


@@ -18,16 +18,33 @@ concurrency:
 cancel-in-progress: true
 jobs:
-test-matrix:
+discover-tests:
 runs-on: ubuntu-latest
+outputs:
+test-type: ${{ steps.generate-matrix.outputs.test-type }}
+steps:
+- name: Checkout repository
+uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+- name: Generate test matrix
+id: generate-matrix
+run: |
+# Get test directories dynamically, excluding non-test directories
+TEST_TYPES=$(find tests/integration -maxdepth 1 -mindepth 1 -type d -printf "%f\n" |
+grep -Ev "^(__pycache__|fixtures|test_cases)$" |
+sort | jq -R -s -c 'split("\n")[:-1]')
+echo "test-type=$TEST_TYPES" >> $GITHUB_OUTPUT
+test-matrix:
+needs: discover-tests
+runs-on: ubuntu-latest
 strategy:
+fail-fast: false
 matrix:
-# Listing tests manually since some of them currently fail
-# TODO: generate matrix list from tests/integration when fixed
-test-type: [agents, inference, datasets, inspect, scoring, post_training, providers, tool_runtime, vector_io]
+test-type: ${{ fromJson(needs.discover-tests.outputs.test-type) }}
 client-type: [library, server]
 python-version: ["3.12", "3.13"]
-fail-fast: false # we want to run all tests regardless of failure
 steps:
 - name: Checkout repository
@@ -53,9 +70,11 @@ jobs:
 - name: Run Integration Tests
 env:
-OLLAMA_INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct" # for server tests
+OLLAMA_INFERENCE_MODEL: "llama3.2:3b-instruct-fp16" # for server tests
 ENABLE_OLLAMA: "ollama" # for server tests
 OLLAMA_URL: "http://0.0.0.0:11434"
+SAFETY_MODEL: "llama-guard3:1b"
+LLAMA_STACK_CLIENT_TIMEOUT: "300" # Increased timeout for eval operations
 # Use 'shell' to get pipefail behavior
 # https://docs.github.com/en/actions/reference/workflow-syntax-for-github-actions#exit-codes-and-error-action-preference
 # TODO: write a precommit hook to detect if a test contains a pipe but does not use 'shell: bash'
@@ -68,8 +87,9 @@ jobs:
 fi
 uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \
 -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \
---text-model="ollama/meta-llama/Llama-3.2-3B-Instruct" \
+--text-model="ollama/llama3.2:3b-instruct-fp16" \
 --embedding-model=all-MiniLM-L6-v2 \
+--safety-shield=ollama \
 --color=yes \
 --capture=tee-sys | tee pytest-${{ matrix.test-type }}.log
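
For readers who want to preview the generated matrix locally, here is a rough Python equivalent of the `find | grep | sort | jq` pipeline in the `discover-tests` job above. The `tests/integration` layout and the excluded directory names come from the workflow; the function name and the small CLI wrapper are illustrative.

```python
# Rough local equivalent of the matrix-discovery step; illustrative only.
import json
from pathlib import Path

EXCLUDED = {"__pycache__", "fixtures", "test_cases"}  # mirrors the grep -Ev filter

def discover_test_types(root: str = "tests/integration") -> list[str]:
    """Return sorted test-type directory names, mirroring find | grep | sort | jq."""
    return sorted(
        p.name for p in Path(root).iterdir() if p.is_dir() and p.name not in EXCLUDED
    )

if __name__ == "__main__":
    # The workflow writes this JSON array to $GITHUB_OUTPUT under `test-type`.
    print(json.dumps(discover_test_types()))
```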


@@ -29,7 +29,7 @@ repos:
 - id: check-toml
 - repo: https://github.com/Lucas-C/pre-commit-hooks
-rev: v1.5.4
+rev: v1.5.5
 hooks:
 - id: insert-license
 files: \.py$|\.sh$
@@ -38,7 +38,7 @@ repos:
 - docs/license_header.txt
 - repo: https://github.com/astral-sh/ruff-pre-commit
-rev: v0.9.4
+rev: v0.12.2
 hooks:
 - id: ruff
 args: [ --fix ]
@@ -46,14 +46,14 @@ repos:
 - id: ruff-format
 - repo: https://github.com/adamchainz/blacken-docs
-rev: 1.19.0
+rev: 1.19.1
 hooks:
 - id: blacken-docs
 additional_dependencies:
 - black==24.3.0
 - repo: https://github.com/astral-sh/uv-pre-commit
-rev: 0.7.8
+rev: 0.7.20
 hooks:
 - id: uv-lock
 - id: uv-export
@@ -66,7 +66,7 @@ repos:
 ]
 - repo: https://github.com/pre-commit/mirrors-mypy
-rev: v1.15.0
+rev: v1.16.1
 hooks:
 - id: mypy
 additional_dependencies:
@@ -133,3 +133,8 @@ repos:
 ci:
 autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks
 autoupdate_commit_msg: ⬆ [pre-commit.ci] pre-commit autoupdate
+autofix_prs: true
+autoupdate_branch: ''
+autoupdate_schedule: weekly
+skip: []
+submodules: false


@@ -66,7 +66,7 @@ You can install the dependencies by running:
 ```bash
 cd llama-stack
-uv sync --extra dev
+uv sync --group dev
 uv pip install -e .
 source .venv/bin/activate
 ```
@@ -168,7 +168,7 @@ manually as they are auto-generated.
 ### Updating the provider documentation
-If you have made changes to a provider's configuration, you should run `./scripts/distro_codegen.py`
+If you have made changes to a provider's configuration, you should run `./scripts/provider_codegen.py`
 to re-generate the documentation. You should not change `docs/source/.../providers/` files manually
 as they are auto-generated.
 Note that the provider "description" field will be used to generate the provider documentation.


@@ -77,7 +77,7 @@ As more providers start supporting Llama 4, you can use them in Llama Stack as w
 To try Llama Stack locally, run:
 ```bash
-curl -LsSf https://github.com/meta-llama/llama-stack/raw/main/install.sh | bash
+curl -LsSf https://github.com/meta-llama/llama-stack/raw/main/scripts/install.sh | bash
 ```
 ### Overview


@@ -14796,7 +14796,8 @@
 "description": "Template for formatting each retrieved chunk in the context. Available placeholders: {index} (1-based chunk ordinal), {chunk.content} (chunk content string), {metadata} (chunk metadata dict). Default: \"Result {index}\\nContent: {chunk.content}\\nMetadata: {metadata}\\n\""
 },
 "mode": {
-"type": "string",
+"$ref": "#/components/schemas/RAGSearchMode",
+"default": "vector",
 "description": "Search mode for retrieval—either \"vector\", \"keyword\", or \"hybrid\". Default \"vector\"."
 },
 "ranker": {
@@ -14831,6 +14832,16 @@
 }
 }
 },
+"RAGSearchMode": {
+"type": "string",
+"enum": [
+"vector",
+"keyword",
+"hybrid"
+],
+"title": "RAGSearchMode",
+"description": "Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search for semantic matching - KEYWORD: Uses keyword-based search for exact matching - HYBRID: Combines both vector and keyword search for better results"
+},
 "RRFRanker": {
 "type": "object",
 "properties": {


@@ -10346,7 +10346,8 @@ components:
 content string), {metadata} (chunk metadata dict). Default: "Result {index}\nContent:
 {chunk.content}\nMetadata: {metadata}\n"
 mode:
-type: string
+$ref: '#/components/schemas/RAGSearchMode'
+default: vector
 description: >-
 Search mode for retrieval—either "vector", "keyword", or "hybrid". Default
 "vector".
@@ -10373,6 +10374,17 @@ components:
 mapping:
 default: '#/components/schemas/DefaultRAGQueryGeneratorConfig'
 llm: '#/components/schemas/LLMRAGQueryGeneratorConfig'
+RAGSearchMode:
+type: string
+enum:
+- vector
+- keyword
+- hybrid
+title: RAGSearchMode
+description: >-
+Search modes for RAG query retrieval: - VECTOR: Uses vector similarity search
+for semantic matching - KEYWORD: Uses keyword-based search for exact matching
+- HYBRID: Combines both vector and keyword search for better results
 RRFRanker:
 type: object
 properties:


@@ -1,5 +1,7 @@
 # The Llama Stack API
+*Originally authored Jul 23, 2024*
 **Authors:**
 * Meta: @raghotham, @ashwinb, @hjshah, @jspisak
@@ -24,7 +26,7 @@ Meta releases weights of both the pretrained and instruction fine-tuned Llama mo
 ### Model Lifecycle
-![Figure 1: Model Life Cycle](../docs/resources/model-lifecycle.png)
+![Figure 1: Model Life Cycle](resources/model-lifecycle.png)
 For each of the operations that need to be performed (e.g. fine tuning, inference, evals etc) during the model life cycle, we identified the capabilities as toolchain APIs that are needed. Some of these capabilities are primitive operations like inference while other capabilities like synthetic data generation are composed of other capabilities. The list of APIs we have identified to support the lifecycle of Llama models is below:
@@ -37,7 +39,7 @@ For each of the operations that need to be performed (e.g. fine tuning, inferenc
 ### Agentic System
-![Figure 2: Agentic System](../docs/resources/agentic-system.png)
+![Figure 2: Agentic System](resources/agentic-system.png)
 In addition to the model lifecycle, we considered the different components involved in an agentic system. Specifically around tool calling and shields. Since the model may decide to call tools, a single model inference call is not enough. Whats needed is an agentic loop consisting of tool calls and inference. The model provides separate tokens representing end-of-message and end-of-turn. A message represents a possible stopping point for execution where the model can inform the execution environment that a tool call needs to be made. The execution environment, upon execution, adds back the result to the context window and makes another inference call. This process can get repeated until an end-of-turn token is generated.
 Note that as of today, in the OSS world, such a “loop” is often coded explicitly via elaborate prompt engineering using a ReAct pattern (typically) or preconstructed execution graph. Llama 3.1 (and future Llamas) attempts to absorb this multi-step reasoning loop inside the main model itself.
@@ -63,9 +65,9 @@ The sequence diagram that details the steps is [here](https://github.com/meta-ll
 We define the Llama Stack as a layer cake shown below.
-![Figure 3: Llama Stack](../docs/resources/llama-stack.png)
+![Figure 3: Llama Stack](resources/llama-stack.png)
-The API is defined in the [YAML](../docs/_static/llama-stack-spec.yaml) and [HTML](../docs/_static/llama-stack-spec.html) files.
+The API is defined in the [YAML](_static/llama-stack-spec.yaml) and [HTML](_static/llama-stack-spec.html) files.
 ## Sample implementations


@@ -145,6 +145,10 @@ $ llama stack build --template starter
 ...
 You can now edit ~/.llama/distributions/llamastack-starter/starter-run.yaml and run `llama stack run ~/.llama/distributions/llamastack-starter/starter-run.yaml`
 ```
+```{tip}
+The generated `run.yaml` file is a starting point for your configuration. For comprehensive guidance on customizing it for your specific needs, infrastructure, and deployment scenarios, see [Customizing Your run.yaml Configuration](customizing_run_yaml.md).
+```
 :::
 :::{tab-item} Building from Scratch


@@ -2,6 +2,10 @@
 The Llama Stack runtime configuration is specified as a YAML file. Here is a simplified version of an example configuration file for the Ollama distribution:
+```{note}
+The default `run.yaml` files generated by templates are starting points for your configuration. For guidance on customizing these files for your specific needs, see [Customizing Your run.yaml Configuration](customizing_run_yaml.md).
+```
 ```{dropdown} 👋 Click here for a Sample Configuration File
 ```yaml


@@ -0,0 +1,40 @@
+# Customizing run.yaml Files
+The `run.yaml` files generated by Llama Stack templates are **starting points** designed to be customized for your specific needs. They are not meant to be used as-is in production environments.
+## Key Points
+- **Templates are starting points**: Generated `run.yaml` files contain defaults for development/testing
+- **Customization expected**: Update URLs, credentials, models, and settings for your environment
+- **Version control separately**: Keep customized configs in your own repository
+- **Environment-specific**: Create different configurations for dev, staging, production
+## What You Can Customize
+You can customize:
+- **Provider endpoints**: Change `http://localhost:8000` to your actual servers
+- **Swap providers**: Replace default providers (e.g., swap Tavily with Brave for search)
+- **Storage paths**: Move from `/tmp/` to production directories
+- **Authentication**: Add API keys, SSL, timeouts
+- **Models**: Different model sizes for dev vs prod
+- **Database settings**: Switch from SQLite to PostgreSQL
+- **Tool configurations**: Add custom tools and integrations
+## Best Practices
+- Use environment variables for secrets and environment-specific values
+- Create separate `run.yaml` files for different environments (dev, staging, prod)
+- Document your changes with comments
+- Test configurations before deployment
+- Keep your customized configs in version control
+Example structure:
+```
+your-project/
+├── configs/
+│ ├── dev-run.yaml
+│ ├── prod-run.yaml
+└── README.md
+```
+The goal is to take the generated template and adapt it to your specific infrastructure and operational needs.
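
The sample configurations elsewhere in this commit use `${env.VAR:=default}` placeholders for environment-specific values. The resolver below is only an illustration of that idea for readers experimenting with their own tooling, not Llama Stack's actual substitution code; the function name and regex are assumptions.

```python
# Illustrative only: a tiny resolver for the ${env.VAR:=default} placeholders
# seen in the sample run.yaml snippets in this diff.
import os
import re

_PLACEHOLDER = re.compile(r"\$\{env\.([A-Za-z_][A-Za-z0-9_]*)(?::=([^}]*))?\}")

def resolve_env_placeholders(text: str) -> str:
    """Replace ${env.VAR:=default} with the environment value or its default."""
    def _sub(match: re.Match) -> str:
        name, default = match.group(1), match.group(2)
        value = os.environ.get(name, default)
        if value is None:
            raise ValueError(f"environment variable {name} is required but not set")
        return value
    return _PLACEHOLDER.sub(_sub, text)

print(resolve_env_placeholders("db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec.db"))
```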


@@ -9,6 +9,7 @@ This section provides an overview of the distributions available in Llama Stack.
 importing_as_library
 configuration
+customizing_run_yaml
 list_of_distributions
 kubernetes_deployment
 building_distro


@@ -6,12 +6,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-export POSTGRES_USER=${POSTGRES_USER:-llamastack}
+export POSTGRES_USER=llamastack
-export POSTGRES_DB=${POSTGRES_DB:-llamastack}
+export POSTGRES_DB=llamastack
-export POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-llamastack}
+export POSTGRES_PASSWORD=llamastack
-export INFERENCE_MODEL=${INFERENCE_MODEL:-meta-llama/Llama-3.2-3B-Instruct}
+export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
-export SAFETY_MODEL=${SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B}
+export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
 # HF_TOKEN should be set by the user; base64 encode it for the secret
 if [ -n "${HF_TOKEN:-}" ]; then


@@ -32,7 +32,7 @@ spec:
 image: vllm/vllm-openai:latest
 command: ["/bin/sh", "-c"]
 args:
-- "vllm serve ${INFERENCE_MODEL} --dtype float16 --enforce-eager --max-model-len 4096 --gpu-memory-utilization 0.6"
+- "vllm serve ${INFERENCE_MODEL} --dtype float16 --enforce-eager --max-model-len 4096 --gpu-memory-utilization 0.6 --enable-auto-tool-choice --tool-call-parser llama4_pythonic"
 env:
 - name: INFERENCE_MODEL
 value: "${INFERENCE_MODEL}"


@@ -13,7 +13,7 @@ Latest Release Notes: [link](https://github.com/meta-llama/llama-stack-client-ko
 *Tagged releases are stable versions of the project. While we strive to maintain a stable main branch, it's not guaranteed to be free of bugs or issues.*
 ## Android Demo App
-Check out our demo app to see how to integrate Llama Stack into your Android app: [Android Demo App](https://github.com/meta-llama/llama-stack-client-kotlin/tree/examples/android_app)
+Check out our demo app to see how to integrate Llama Stack into your Android app: [Android Demo App](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/examples/android_app)
 The key files in the app are `ExampleLlamaStackLocalInference.kt`, `ExampleLlamaStackRemoteInference.kts`, and `MainActivity.java`. With encompassed business logic, the app shows how to use Llama Stack for both the environments.
@@ -68,7 +68,7 @@ Ensure the Llama Stack server version is the same as the Kotlin SDK Library for
 Other inference providers: [Table](https://llama-stack.readthedocs.io/en/latest/index.html#supported-llama-stack-implementations)
-How to set remote localhost in Demo App: [Settings](https://github.com/meta-llama/llama-stack-apps/tree/main/examples/android_app#settings)
+How to set remote localhost in Demo App: [Settings](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/examples/android_app#settings)
 ### Initialize the Client
 A client serves as the primary interface for interacting with a specific inference type and its associated parameters. Only after client is initialized then you can configure and start inferences.
@@ -135,7 +135,7 @@ val result = client!!.inference().chatCompletionStreaming(
 ### Setup Custom Tool Calling
-Android demo app for more details: [Custom Tool Calling](https://github.com/meta-llama/llama-stack-apps/tree/main/examples/android_app#tool-calling)
+Android demo app for more details: [Custom Tool Calling](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/examples/android_app#tool-calling)
 ## Advanced Users


@@ -54,7 +54,7 @@ Llama Stack is a server that exposes multiple APIs, you connect with it using th
 You can use Python to build and run the Llama Stack server, which is useful for testing and development.
 Llama Stack uses a [YAML configuration file](../distributions/configuration.md) to specify the stack setup,
-which defines the providers and their settings.
+which defines the providers and their settings. The generated configuration serves as a starting point that you can [customize for your specific needs](../distributions/customizing_run_yaml.md).
 Now let's build and run the Llama Stack config for Ollama.
 We use `starter` as template. By default all providers are disabled, this requires enable ollama by passing environment variables.
@@ -77,7 +77,7 @@ ENABLE_OLLAMA=ollama INFERENCE_MODEL="llama3.2:3b" llama stack build --template
 You can use a container image to run the Llama Stack server. We provide several container images for the server
 component that works with different inference providers out of the box. For this guide, we will use
 `llamastack/distribution-starter` as the container image. If you'd like to build your own image or customize the
-configurations, please check out [this guide](../references/index.md).
+configurations, please check out [this guide](../distributions/building_distro.md).
 First lets setup some environment variables and create a local directory to mount into the containers file system.
 ```bash
 export INFERENCE_MODEL="llama3.2:3b"


@@ -11,7 +11,7 @@ Please refer to the remote provider documentation.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
 | `db_path` | `<class 'str'>` | No | PydanticUndefined | |
-| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | |
+| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) |
 | `consistency_level` | `<class 'str'>` | No | Strong | The consistency level of the Milvus server |
 ## Sample Configuration


@@ -205,12 +205,16 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
-| `db_path` | `<class 'str'>` | No | PydanticUndefined | |
+| `db_path` | `<class 'str'>` | No | PydanticUndefined | Path to the SQLite database file |
+| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) |
 ## Sample Configuration
 ```yaml
 db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec.db
+kvstore:
+type: sqlite
+db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec_registry.db
 ```


@@ -10,12 +10,16 @@ Please refer to the sqlite-vec provider documentation.
 | Field | Type | Required | Default | Description |
 |-------|------|----------|---------|-------------|
-| `db_path` | `<class 'str'>` | No | PydanticUndefined | |
+| `db_path` | `<class 'str'>` | No | PydanticUndefined | Path to the SQLite database file |
+| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig` | No | sqlite | Config for KV store backend (SQLite only for now) |
 ## Sample Configuration
 ```yaml
 db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec.db
+kvstore:
+type: sqlite
+db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/sqlite_vec_registry.db
 ```


@@ -87,6 +87,20 @@ class RAGQueryGenerator(Enum):
 custom = "custom"
+@json_schema_type
+class RAGSearchMode(Enum):
+"""
+Search modes for RAG query retrieval:
+- VECTOR: Uses vector similarity search for semantic matching
+- KEYWORD: Uses keyword-based search for exact matching
+- HYBRID: Combines both vector and keyword search for better results
+"""
+VECTOR = "vector"
+KEYWORD = "keyword"
+HYBRID = "hybrid"
 @json_schema_type
 class DefaultRAGQueryGeneratorConfig(BaseModel):
 type: Literal["default"] = "default"
@@ -128,7 +142,7 @@ class RAGQueryConfig(BaseModel):
 max_tokens_in_context: int = 4096
 max_chunks: int = 5
 chunk_template: str = "Result {index}\nContent: {chunk.content}\nMetadata: {metadata}\n"
-mode: str | None = None
+mode: RAGSearchMode | None = RAGSearchMode.VECTOR
 ranker: Ranker | None = Field(default=None) # Only used for hybrid mode
 @field_validator("chunk_template")
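
A short usage sketch for the new enum-typed `mode` field. It assumes `RAGQueryConfig` and `RAGSearchMode` can be imported from `llama_stack.apis.tools`; adjust the import path to wherever these types live in your checkout.

```python
# Illustrative usage of the enum-typed search mode; import path is an assumption.
from llama_stack.apis.tools import RAGQueryConfig, RAGSearchMode

# Defaults now resolve to vector search instead of a bare string.
default_config = RAGQueryConfig()
assert default_config.mode == RAGSearchMode.VECTOR

# Hybrid retrieval is selected with the enum rather than a free-form string,
# so typos like "hybird" fail validation instead of silently falling through.
hybrid_config = RAGQueryConfig(mode=RAGSearchMode.HYBRID, max_chunks=10)
```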


@@ -93,7 +93,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
 )
 sys.exit(1)
 elif args.providers:
-providers = dict()
+providers_list: dict[str, str | list[str]] = dict()
 for api_provider in args.providers.split(","):
 if "=" not in api_provider:
 cprint(
@@ -112,7 +112,15 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
 )
 sys.exit(1)
 if provider in providers_for_api:
-providers.setdefault(api, []).append(provider)
+if api not in providers_list:
+providers_list[api] = []
+# Use type guarding to ensure we have a list
+provider_value = providers_list[api]
+if isinstance(provider_value, list):
+provider_value.append(provider)
+else:
+# Convert string to list and append
+providers_list[api] = [provider_value, provider]
 else:
 cprint(
 f"{provider} is not a valid provider for the {api} API.",
@@ -121,7 +129,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
 )
 sys.exit(1)
 distribution_spec = DistributionSpec(
-providers=providers,
+providers=providers_list,
 description=",".join(args.providers),
 )
 if not args.image_type:
@@ -182,7 +190,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
 cprint("Tip: use <TAB> to see options for the providers.\n", color="green", file=sys.stderr)
-providers = dict()
+providers: dict[str, str | list[str]] = dict()
 for api, providers_for_api in get_provider_registry().items():
 available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")]
 if not available_providers:
@@ -371,10 +379,16 @@ def _run_stack_build_command_from_build_config(
 if not image_name:
 raise ValueError("Please specify an image name when building a venv image")
+# At this point, image_name should be guaranteed to be a string
+if image_name is None:
+raise ValueError("image_name should not be None after validation")
 if template_name:
 build_dir = DISTRIBS_BASE_DIR / template_name
 build_file_path = build_dir / f"{template_name}-build.yaml"
 else:
+if image_name is None:
+raise ValueError("image_name cannot be None")
 build_dir = DISTRIBS_BASE_DIR / image_name
 build_file_path = build_dir / f"{image_name}-build.yaml"
@@ -395,7 +409,7 @@ def _run_stack_build_command_from_build_config(
 build_file_path,
 image_name,
 template_or_config=template_name or config_path or str(build_file_path),
-run_config=run_config_file,
+run_config=run_config_file.as_posix() if run_config_file else None,
 )
 if return_code != 0:
 raise RuntimeError(f"Failed to build image {image_name}")
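
A minimal sketch of the `--providers` parsing that the hunk above makes type-safe: an `api=provider` list becomes a dict of provider lists. Registry validation is omitted here and the provider names are illustrative.

```python
# Illustrative parsing of a "--providers" style argument; not the CLI's actual code.
def parse_providers(spec: str) -> dict[str, list[str]]:
    providers: dict[str, list[str]] = {}
    for pair in spec.split(","):
        if "=" not in pair:
            raise ValueError(f"expected api=provider, got {pair!r}")
        api, provider = pair.split("=", 1)
        providers.setdefault(api, []).append(provider)
    return providers

print(parse_providers("inference=remote::ollama,safety=inline::llama-guard"))
# {'inference': ['remote::ollama'], 'safety': ['inline::llama-guard']}
```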


@@ -83,46 +83,57 @@ class StackRun(Subcommand):
 return ImageType.CONDA.value, args.image_name
 return args.image_type, args.image_name
+def _resolve_config_and_template(self, args: argparse.Namespace) -> tuple[Path | None, str | None]:
+"""Resolve config file path and template name from args.config"""
+from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
+if not args.config:
+return None, None
+config_file = Path(args.config)
+has_yaml_suffix = args.config.endswith(".yaml")
+template_name = None
+if not config_file.exists() and not has_yaml_suffix:
+# check if this is a template
+config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
+if config_file.exists():
+template_name = args.config
+if not config_file.exists() and not has_yaml_suffix:
+# check if it's a build config saved to ~/.llama dir
+config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
+if not config_file.exists():
+self.parser.error(
+f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
+)
+if not config_file.is_file():
+self.parser.error(
+f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
+)
+return config_file, template_name
 def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
 import yaml
 from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
-from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
 from llama_stack.distribution.utils.exec import formulate_run_args, run_command
 if args.enable_ui:
 self._start_ui_development_server(args.port)
 image_type, image_name = self._get_image_type_and_name(args)
+# Resolve config file and template name first
+config_file, template_name = self._resolve_config_and_template(args)
 # Check if config is required based on image type
-if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not args.config:
+if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not config_file:
 self.parser.error("Config file is required for venv and conda environments")
-if args.config:
+if config_file:
-config_file = Path(args.config)
-has_yaml_suffix = args.config.endswith(".yaml")
-template_name = None
-if not config_file.exists() and not has_yaml_suffix:
-# check if this is a template
-config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
-if config_file.exists():
-template_name = args.config
-if not config_file.exists() and not has_yaml_suffix:
-# check if it's a build config saved to ~/.llama dir
-config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml")
-if not config_file.exists():
-self.parser.error(
-f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file"
-)
-if not config_file.is_file():
-self.parser.error(
-f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}"
-)
 logger.info(f"Using run configuration: {config_file}")
 try:
@@ -138,8 +149,6 @@ class StackRun(Subcommand):
 self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
 else:
 config = None
-config_file = None
-template_name = None
 # If neither image type nor image name is provided, assume the server should be run directly
 # using the current environment packages.
@@ -172,10 +181,7 @@ class StackRun(Subcommand):
 run_args.extend([str(args.port)])
 if config_file:
-if template_name:
-run_args.extend(["--template", str(template_name)])
-else:
-run_args.extend(["--config", str(config_file)])
+run_args.extend(["--config", str(config_file)])
 if args.env:
 for env_var in args.env:


@@ -81,7 +81,7 @@ def is_action_allowed(
 if not len(policy):
 policy = default_policy()
-qualified_resource_id = resource.type + "::" + resource.identifier
+qualified_resource_id = f"{resource.type}::{resource.identifier}"
 for rule in policy:
 if rule.forbid and matches_scope(rule.forbid, action, qualified_resource_id, user.principal):
 if rule.when:


@@ -96,7 +96,7 @@ FROM $container_base
 WORKDIR /app
 # We install the Python 3.12 dev headers and build tools so that any
-# Cextension wheels (e.g. polyleven, faisscpu) can compile successfully.
+# C-extension wheels (e.g. polyleven, faiss-cpu) can compile successfully.
 RUN dnf -y update && dnf install -y iputils git net-tools wget \
 vim-minimal python3.12 python3.12-pip python3.12-wheel \
@@ -169,7 +169,7 @@ if [ -n "$run_config" ]; then
 echo "Copying external providers directory: $external_providers_dir"
 cp -r "$external_providers_dir" "$BUILD_CONTEXT_DIR/providers.d"
 add_to_container << EOF
-COPY --chmod=g+w providers.d /.llama/providers.d
+COPY providers.d /.llama/providers.d
 EOF
 fi


@@ -445,7 +445,7 @@ def main(args: argparse.Namespace | None = None):
 logger.info(log_line)
 logger.info("Run configuration:")
-safe_config = redact_sensitive_fields(config.model_dump())
+safe_config = redact_sensitive_fields(config.model_dump(mode="json"))
 logger.info(yaml.dump(safe_config, indent=2))
 app = FastAPI(


@@ -98,6 +98,7 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
 method = getattr(impls[api], register_method)
 for obj in objects:
+logger.debug(f"registering {rsrc.capitalize()} {obj} for provider {obj.provider_id}")
 # Do not register models on disabled providers
 if hasattr(obj, "provider_id") and obj.provider_id is not None and obj.provider_id == "__disabled__":
 logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled provider.")
@@ -112,6 +113,11 @@ async def register_resources(run_config: StackRunConfig, impls: dict[Api, Any]):
 ):
 logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled model.")
 continue
+if hasattr(obj, "shield_id") and obj.shield_id is not None and obj.shield_id == "__disabled__":
+logger.debug(f"Skipping {rsrc.capitalize()} registration for disabled shield.")
+continue
 # we want to maintain the type information in arguments to method.
 # instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
 # we use model_dump() to find all the attrs and then getattr to get the still typed value.


@@ -6,12 +6,9 @@
 from collections.abc import AsyncGenerator
 from contextvars import ContextVar
-from typing import TypeVar
-T = TypeVar("T")
-def preserve_contexts_async_generator(
+def preserve_contexts_async_generator[T](
 gen: AsyncGenerator[T, None], context_vars: list[ContextVar]
 ) -> AsyncGenerator[T, None]:
 """


@@ -123,7 +123,8 @@ class TorchtunePostTrainingImpl:
 training_config: TrainingConfig,
 hyperparam_search_config: dict[str, Any],
 logger_config: dict[str, Any],
-) -> PostTrainingJob: ...
+) -> PostTrainingJob:
+raise NotImplementedError()
 async def get_training_jobs(self) -> ListPostTrainingJobsResponse:
 return ListPostTrainingJobsResponse(


@@ -146,10 +146,9 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
 pass
 async def register_shield(self, shield: Shield) -> None:
-if shield.provider_resource_id not in LLAMA_GUARD_MODEL_IDS:
-raise ValueError(
-f"Unsupported Llama Guard type: {shield.provider_resource_id}. Allowed types: {LLAMA_GUARD_MODEL_IDS}"
-)
+# Allow any model to be registered as a shield
+# The model will be validated during runtime when making inference calls
+pass
 async def run_shield(
 self,
@@ -167,11 +166,25 @@ class LlamaGuardSafetyImpl(Safety, ShieldsProtocolPrivate):
 if len(messages) > 0 and messages[0].role != Role.user.value:
 messages[0] = UserMessage(content=messages[0].content)
-model = LLAMA_GUARD_MODEL_IDS[shield.provider_resource_id]
+# Use the inference API's model resolution instead of hardcoded mappings
+# This allows the shield to work with any registered model
+model_id = shield.provider_resource_id
+# Determine safety categories based on the model type
+# For known Llama Guard models, use specific categories
+if model_id in LLAMA_GUARD_MODEL_IDS:
+# Use the mapped model for categories but the original model_id for inference
+mapped_model = LLAMA_GUARD_MODEL_IDS[model_id]
+safety_categories = MODEL_TO_SAFETY_CATEGORIES_MAP.get(mapped_model, DEFAULT_LG_V3_SAFETY_CATEGORIES)
+else:
+# For unknown models, use default Llama Guard 3 8B categories
+safety_categories = DEFAULT_LG_V3_SAFETY_CATEGORIES + [CAT_CODE_INTERPRETER_ABUSE]
 impl = LlamaGuardShield(
-model=model,
+model=model_id,
 inference_api=self.inference_api,
 excluded_categories=self.config.excluded_categories,
+safety_categories=safety_categories,
 )
 return await impl.run(messages)
@@ -183,20 +196,21 @@ class LlamaGuardShield:
 model: str,
 inference_api: Inference,
 excluded_categories: list[str] | None = None,
+safety_categories: list[str] | None = None,
 ):
 if excluded_categories is None:
 excluded_categories = []
+if safety_categories is None:
+safety_categories = []
 assert len(excluded_categories) == 0 or all(
 x in SAFETY_CATEGORIES_TO_CODE_MAP.values() for x in excluded_categories
 ), "Invalid categories in excluded categories. Expected format is ['S1', 'S2', ..]"
-if model not in MODEL_TO_SAFETY_CATEGORIES_MAP:
-raise ValueError(f"Unsupported model: {model}")
 self.model = model
 self.inference_api = inference_api
 self.excluded_categories = excluded_categories
+self.safety_categories = safety_categories
 def check_unsafe_response(self, response: str) -> str | None:
 match = re.match(r"^unsafe\n(.*)$", response)
@@ -214,7 +228,7 @@ class LlamaGuardShield:
 final_categories = []
-all_categories = MODEL_TO_SAFETY_CATEGORIES_MAP[self.model]
+all_categories = self.safety_categories
 for cat in all_categories:
 cat_code = SAFETY_CATEGORIES_TO_CODE_MAP[cat]
 if cat_code in excluded_categories:


@@ -267,6 +267,7 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
 assert self.kvstore is not None
 key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
 await self.kvstore.set(key=key, value=json.dumps(store_info))
+self.openai_vector_stores[store_id] = store_info
 async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
 """Load all vector store metadata from kvstore."""
@@ -286,17 +287,20 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
 assert self.kvstore is not None
 key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
 await self.kvstore.set(key=key, value=json.dumps(store_info))
+self.openai_vector_stores[store_id] = store_info
 async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
 """Delete vector store metadata from kvstore."""
 assert self.kvstore is not None
 key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
 await self.kvstore.delete(key)
+if store_id in self.openai_vector_stores:
+del self.openai_vector_stores[store_id]
 async def _save_openai_vector_store_file(
 self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
 ) -> None:
-"""Save vector store file metadata to kvstore."""
+"""Save vector store file data to kvstore."""
 assert self.kvstore is not None
 key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
 await self.kvstore.set(key=key, value=json.dumps(file_info))
@@ -324,7 +328,16 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPr
 await self.kvstore.set(key=key, value=json.dumps(file_info))
 async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
-"""Delete vector store file metadata from kvstore."""
+"""Delete vector store data from kvstore."""
 assert self.kvstore is not None
-key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
-await self.kvstore.delete(key)
+keys_to_delete = [
+f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}",
+f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}",
+]
+for key in keys_to_delete:
+try:
+await self.kvstore.delete(key)
+except Exception as e:
+logger.warning(f"Failed to delete key {key}: {e}")
+continue
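
The faiss changes above keep the in-memory `openai_vector_stores` dict in step with every kvstore write and delete. The sketch below shows that write-through pattern with a toy dict-backed store; `DictKVStore` and `StoreCache` are stand-ins for illustration, not Llama Stack classes.

```python
# Illustrative write-through cache over a toy async KV store.
import asyncio
import json
from typing import Any

class DictKVStore:
    """Toy async KV store backed by a plain dict."""
    def __init__(self) -> None:
        self._data: dict[str, str] = {}

    async def set(self, key: str, value: str) -> None:
        self._data[key] = value

    async def delete(self, key: str) -> None:
        self._data.pop(key, None)

PREFIX = "openai_vector_stores:"

class StoreCache:
    def __init__(self, kvstore: DictKVStore) -> None:
        self.kvstore = kvstore
        self.openai_vector_stores: dict[str, dict[str, Any]] = {}

    async def save(self, store_id: str, info: dict[str, Any]) -> None:
        await self.kvstore.set(f"{PREFIX}{store_id}", json.dumps(info))
        self.openai_vector_stores[store_id] = info  # keep cache in sync with storage

    async def delete(self, store_id: str) -> None:
        await self.kvstore.delete(f"{PREFIX}{store_id}")
        self.openai_vector_stores.pop(store_id, None)

async def main() -> None:
    cache = StoreCache(DictKVStore())
    await cache.save("vs_123", {"id": "vs_123", "name": "demo"})
    await cache.delete("vs_123")
    print(cache.openai_vector_stores)  # {}

asyncio.run(main())
```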


@@ -18,7 +18,7 @@ from llama_stack.schema_utils import json_schema_type
 @json_schema_type
 class MilvusVectorIOConfig(BaseModel):
 db_path: str
-kvstore: KVStoreConfig
+kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
 consistency_level: str = Field(description="The consistency level of the Milvus server", default="Strong")
 @classmethod


@@ -6,14 +6,24 @@
 from typing import Any
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+from llama_stack.providers.utils.kvstore.config import (
+KVStoreConfig,
+SqliteKVStoreConfig,
+)
 class SQLiteVectorIOConfig(BaseModel):
-db_path: str
+db_path: str = Field(description="Path to the SQLite database file")
+kvstore: KVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
 @classmethod
 def sample_run_config(cls, __distro_dir__: str) -> dict[str, Any]:
 return {
 "db_path": "${env.SQLITE_STORE_DIR:=" + __distro_dir__ + "}/" + "sqlite_vec.db",
+"kvstore": SqliteKVStoreConfig.sample_run_config(
+__distro_dir__=__distro_dir__,
+db_name="sqlite_vec_registry.db",
+),
 }


@@ -24,6 +24,8 @@ from llama_stack.apis.vector_io import (
 VectorIO,
 )
 from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
+from llama_stack.providers.utils.kvstore import kvstore_impl
+from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
 from llama_stack.providers.utils.memory.vector_store import (
 RERANKER_TYPE_RRF,
@@ -40,6 +42,13 @@ KEYWORD_SEARCH = "keyword"
 HYBRID_SEARCH = "hybrid"
 SEARCH_MODES = {VECTOR_SEARCH, KEYWORD_SEARCH, HYBRID_SEARCH}
+VERSION = "v3"
+VECTOR_DBS_PREFIX = f"vector_dbs:sqlite_vec:{VERSION}::"
+VECTOR_INDEX_PREFIX = f"vector_index:sqlite_vec:{VERSION}::"
+OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:sqlite_vec:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:sqlite_vec:{VERSION}::"
+OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:sqlite_vec:{VERSION}::"
 def serialize_vector(vector: list[float]) -> bytes:
 """Serialize a list of floats into a compact binary representation."""
@@ -117,13 +126,14 @@ class SQLiteVecIndex(EmbeddingIndex):
 - An FTS5 table (fts_chunks_{bank_id}) for full-text keyword search.
 """
-def __init__(self, dimension: int, db_path: str, bank_id: str):
+def __init__(self, dimension: int, db_path: str, bank_id: str, kvstore: KVStore | None = None):
 self.dimension = dimension
 self.db_path = db_path
 self.bank_id = bank_id
 self.metadata_table = f"chunks_{bank_id}".replace("-", "_")
 self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_")
 self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_")
+self.kvstore = kvstore
 @classmethod
 async def create(cls, dimension: int, db_path: str, bank_id: str):
@@ -425,27 +435,116 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
 self.files_api = files_api
 self.cache: dict[str, VectorDBWithIndex] = {}
 self.openai_vector_stores: dict[str, dict[str, Any]] = {}
+self.kvstore: KVStore | None = None
 async def initialize(self) -> None:
-def _setup_connection():
-# Open a connection to the SQLite database (the file is specified in the config).
+self.kvstore = await kvstore_impl(self.config.kvstore)
+start_key = VECTOR_DBS_PREFIX
+end_key = f"{VECTOR_DBS_PREFIX}\xff"
+stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
+for db_json in stored_vector_dbs:
+vector_db = VectorDB.model_validate_json(db_json)
+index = await SQLiteVecIndex.create(
+vector_db.embedding_dimension,
+self.config.db_path,
+vector_db.identifier,
+)
+self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
+# load any existing OpenAI vector stores
+self.openai_vector_stores = await self._load_openai_vector_stores()
+async def shutdown(self) -> None:
+# nothing to do since we don't maintain a persistent connection
+pass
+async def list_vector_dbs(self) -> list[VectorDB]:
+return [v.vector_db for v in self.cache.values()]
+async def register_vector_db(self, vector_db: VectorDB) -> None:
+index = await SQLiteVecIndex.create(
+vector_db.embedding_dimension,
+self.config.db_path,
+vector_db.identifier,
+)
+self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
+async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
+if vector_db_id in self.cache:
+return self.cache[vector_db_id]
+if self.vector_db_store is None:
+raise ValueError(f"Vector DB {vector_db_id} not found")
+vector_db = self.vector_db_store.get_vector_db(vector_db_id)
+if not vector_db:
+raise ValueError(f"Vector DB {vector_db_id} not found")
+index = VectorDBWithIndex(
+vector_db=vector_db,
+index=SQLiteVecIndex(
+dimension=vector_db.embedding_dimension,
+db_path=self.config.db_path,
+bank_id=vector_db.identifier,
+kvstore=self.kvstore,
+),
+inference_api=self.inference_api,
+)
+self.cache[vector_db_id] = index
+return index
+async def unregister_vector_db(self, vector_db_id: str) -> None:
+if vector_db_id not in self.cache:
+logger.warning(f"Vector DB {vector_db_id} not found")
+return
+await self.cache[vector_db_id].index.delete()
+del self.cache[vector_db_id]
+# OpenAI Vector Store Mixin abstract method implementations
+async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+"""Save vector store metadata to SQLite database."""
+assert self.kvstore is not None
+key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+await self.kvstore.set(key=key, value=json.dumps(store_info))
+self.openai_vector_stores[store_id] = store_info
+async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
+"""Load all vector store metadata from SQLite database."""
+assert self.kvstore is not None
+start_key = OPENAI_VECTOR_STORES_PREFIX
+end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
+stored_openai_stores = await self.kvstore.values_in_range(start_key, end_key)
+stores = {}
+for store_data in stored_openai_stores:
+store_info = json.loads(store_data)
+stores[store_info["id"]] = store_info
+return stores
+async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
+"""Update vector store metadata in SQLite database."""
+assert self.kvstore is not None
+key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+await self.kvstore.set(key=key, value=json.dumps(store_info))
+self.openai_vector_stores[store_id] = store_info
+async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
+"""Delete vector store metadata from SQLite database."""
+assert self.kvstore is not None
+key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
+await self.kvstore.delete(key)
+if store_id in self.openai_vector_stores:
+del self.openai_vector_stores[store_id]
+async def _save_openai_vector_store_file(
+self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
+) -> None:
+"""Save vector store file metadata to SQLite database."""
+def _create_or_store():
 connection = _create_sqlite_connection(self.config.db_path)
 cur = connection.cursor()
 try:
-# Create a table to persist vector DB registrations.
-cur.execute("""
-CREATE TABLE IF NOT EXISTS vector_dbs (
-id TEXT PRIMARY KEY,
-metadata TEXT
-);
-""")
-# Create a table to persist OpenAI vector stores.
-cur.execute("""
-CREATE TABLE IF NOT EXISTS openai_vector_stores (
-id TEXT PRIMARY KEY,
-metadata TEXT
-);
-""")
 # Create a table to persist OpenAI vector store files.
 cur.execute("""
 CREATE TABLE IF NOT EXISTS openai_vector_store_files (
@@ -464,168 +563,6 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
 );
 """)
 connection.commit()
-# Load any existing vector DB registrations.
-cur.execute("SELECT metadata FROM vector_dbs")
vector_db_rows = cur.fetchall()
return vector_db_rows
finally:
cur.close()
connection.close()
vector_db_rows = await asyncio.to_thread(_setup_connection)
# Load existing vector DBs
for row in vector_db_rows:
vector_db_data = row[0]
vector_db = VectorDB.model_validate_json(vector_db_data)
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
# Load existing OpenAI vector stores using the mixin method
self.openai_vector_stores = await self._load_openai_vector_stores()
async def shutdown(self) -> None:
# nothing to do since we don't maintain a persistent connection
pass
async def register_vector_db(self, vector_db: VectorDB) -> None:
def _register_db():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)",
(vector_db.identifier, vector_db.model_dump_json()),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_register_db)
index = await SQLiteVecIndex.create(
vector_db.embedding_dimension,
self.config.db_path,
vector_db.identifier,
)
self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api)
async def list_vector_dbs(self) -> list[VectorDB]:
return [v.vector_db for v in self.cache.values()]
async def unregister_vector_db(self, vector_db_id: str) -> None:
if vector_db_id not in self.cache:
logger.warning(f"Vector DB {vector_db_id} not found")
return
await self.cache[vector_db_id].index.delete()
del self.cache[vector_db_id]
def _delete_vector_db_from_registry():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,))
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete_vector_db_from_registry)
# OpenAI Vector Store Mixin abstract method implementations
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Save vector store metadata to SQLite database."""
def _store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"INSERT OR REPLACE INTO openai_vector_stores (id, metadata) VALUES (?, ?)",
(store_id, json.dumps(store_info)),
)
connection.commit()
except Exception as e:
logger.error(f"Error saving openai vector store {store_id}: {e}")
raise
finally:
cur.close()
connection.close()
try:
await asyncio.to_thread(_store)
except Exception as e:
logger.error(f"Error saving openai vector store {store_id}: {e}")
raise
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
"""Load all vector store metadata from SQLite database."""
def _load():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("SELECT metadata FROM openai_vector_stores")
rows = cur.fetchall()
return rows
finally:
cur.close()
connection.close()
rows = await asyncio.to_thread(_load)
stores = {}
for row in rows:
store_data = row[0]
store_info = json.loads(store_data)
stores[store_info["id"]] = store_info
return stores
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
"""Update vector store metadata in SQLite database."""
def _update():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute(
"UPDATE openai_vector_stores SET metadata = ? WHERE id = ?",
(json.dumps(store_info), store_id),
)
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_update)
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
"""Delete vector store metadata from SQLite database."""
def _delete():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute("DELETE FROM openai_vector_stores WHERE id = ?", (store_id,))
connection.commit()
finally:
cur.close()
connection.close()
await asyncio.to_thread(_delete)
async def _save_openai_vector_store_file(
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
) -> None:
"""Save vector store file metadata to SQLite database."""
def _store():
connection = _create_sqlite_connection(self.config.db_path)
cur = connection.cursor()
try:
cur.execute( cur.execute(
"INSERT OR REPLACE INTO openai_vector_store_files (store_id, file_id, metadata) VALUES (?, ?, ?)", "INSERT OR REPLACE INTO openai_vector_store_files (store_id, file_id, metadata) VALUES (?, ?, ?)",
(store_id, file_id, json.dumps(file_info)), (store_id, file_id, json.dumps(file_info)),
@ -643,7 +580,7 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
connection.close() connection.close()
try: try:
await asyncio.to_thread(_store) await asyncio.to_thread(_create_or_store)
except Exception as e: except Exception as e:
logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}") logger.error(f"Error saving openai vector store file {store_id} {file_id}: {e}")
raise raise
@ -722,6 +659,10 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
cur.execute( cur.execute(
"DELETE FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?", (store_id, file_id) "DELETE FROM openai_vector_store_files WHERE store_id = ? AND file_id = ?", (store_id, file_id)
) )
cur.execute(
"DELETE FROM openai_vector_store_files_contents WHERE store_id = ? AND file_id = ?",
(store_id, file_id),
)
connection.commit() connection.commit()
finally: finally:
cur.close() cur.close()
@ -730,15 +671,17 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtoc
await asyncio.to_thread(_delete) await asyncio.to_thread(_delete)
async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None: async def insert_chunks(self, vector_db_id: str, chunks: list[Chunk], ttl_seconds: int | None = None) -> None:
if vector_db_id not in self.cache: index = await self._get_and_cache_vector_db_index(vector_db_id)
raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}") if not index:
raise ValueError(f"Vector DB {vector_db_id} not found")
# The VectorDBWithIndex helper is expected to compute embeddings via the inference_api # The VectorDBWithIndex helper is expected to compute embeddings via the inference_api
# and then call our index's add_chunks. # and then call our index's add_chunks.
await self.cache[vector_db_id].insert_chunks(chunks) await index.insert_chunks(chunks)
async def query_chunks( async def query_chunks(
self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None self, vector_db_id: str, query: Any, params: dict[str, Any] | None = None
) -> QueryChunksResponse: ) -> QueryChunksResponse:
if vector_db_id not in self.cache: index = await self._get_and_cache_vector_db_index(vector_db_id)
if not index:
raise ValueError(f"Vector DB {vector_db_id} not found") raise ValueError(f"Vector DB {vector_db_id} not found")
return await self.cache[vector_db_id].query_chunks(query, params) return await index.query_chunks(query, params)
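The rewritten adapter above drops the hand-rolled SQLite bookkeeping tables in favour of a key-value store: registrations and vector-store metadata are written under the versioned prefixes defined at the top of the file and read back with a range scan from the prefix up to the prefix plus "\xff". The sketch below illustrates that key layout; InMemoryKVStore and the demo data are hypothetical stand-ins for the real KVStore implementations, shown only to make the prefix/range-scan idiom concrete.

import asyncio
import json

VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:sqlite_vec:{VERSION}::"

class InMemoryKVStore:
    """Hypothetical stand-in exposing the same calls used above (set/delete/values_in_range)."""

    def __init__(self) -> None:
        self._data: dict[str, str] = {}

    async def set(self, key: str, value: str) -> None:
        self._data[key] = value

    async def delete(self, key: str) -> None:
        self._data.pop(key, None)

    async def values_in_range(self, start_key: str, end_key: str) -> list[str]:
        # Lexicographic range scan, mirroring the "<prefix>" .. "<prefix>\xff" idiom.
        return [v for k, v in sorted(self._data.items()) if start_key <= k <= end_key]

async def demo() -> None:
    kv = InMemoryKVStore()
    await kv.set(f"{VECTOR_DBS_PREFIX}docs-db", json.dumps({"identifier": "docs-db", "embedding_dimension": 768}))
    await kv.set(f"{VECTOR_DBS_PREFIX}faq-db", json.dumps({"identifier": "faq-db", "embedding_dimension": 384}))
    stored = await kv.values_in_range(VECTOR_DBS_PREFIX, f"{VECTOR_DBS_PREFIX}\xff")
    print([json.loads(v)["identifier"] for v in stored])  # ['docs-db', 'faq-db']

asyncio.run(demo())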

View file

@ -15,21 +15,26 @@ LLM_MODEL_IDS = [
"anthropic/claude-3-5-haiku-latest", "anthropic/claude-3-5-haiku-latest",
] ]
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [ MODEL_ENTRIES = (
ProviderModelEntry( [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
provider_model_id="anthropic/voyage-3", + [
model_type=ModelType.embedding, ProviderModelEntry(
metadata={"embedding_dimension": 1024, "context_length": 32000}, provider_model_id="anthropic/voyage-3",
), model_type=ModelType.embedding,
ProviderModelEntry( metadata={"embedding_dimension": 1024, "context_length": 32000},
provider_model_id="anthropic/voyage-3-lite", ),
model_type=ModelType.embedding, ProviderModelEntry(
metadata={"embedding_dimension": 512, "context_length": 32000}, provider_model_id="anthropic/voyage-3-lite",
), model_type=ModelType.embedding,
ProviderModelEntry( metadata={"embedding_dimension": 512, "context_length": 32000},
provider_model_id="anthropic/voyage-code-3", ),
model_type=ModelType.embedding, ProviderModelEntry(
metadata={"embedding_dimension": 1024, "context_length": 32000}, provider_model_id="anthropic/voyage-code-3",
), model_type=ModelType.embedding,
] metadata={"embedding_dimension": 1024, "context_length": 32000},
),
]
+ SAFETY_MODELS_ENTRIES
)

View file

@ -9,6 +9,10 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry, build_hf_repo_model_entry,
) )
SAFETY_MODELS_ENTRIES = []
# https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html
MODEL_ENTRIES = [ MODEL_ENTRIES = [
build_hf_repo_model_entry( build_hf_repo_model_entry(
"meta.llama3-1-8b-instruct-v1:0", "meta.llama3-1-8b-instruct-v1:0",
@ -22,4 +26,4 @@ MODEL_ENTRIES = [
"meta.llama3-1-405b-instruct-v1:0", "meta.llama3-1-405b-instruct-v1:0",
CoreModelId.llama3_1_405b_instruct.value, CoreModelId.llama3_1_405b_instruct.value,
), ),
] ] + SAFETY_MODELS_ENTRIES

View file

@ -9,6 +9,9 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry, build_hf_repo_model_entry,
) )
SAFETY_MODELS_ENTRIES = []
# https://inference-docs.cerebras.ai/models
MODEL_ENTRIES = [ MODEL_ENTRIES = [
build_hf_repo_model_entry( build_hf_repo_model_entry(
"llama3.1-8b", "llama3.1-8b",
@ -18,4 +21,8 @@ MODEL_ENTRIES = [
"llama-3.3-70b", "llama-3.3-70b",
CoreModelId.llama3_3_70b_instruct.value, CoreModelId.llama3_3_70b_instruct.value,
), ),
] build_hf_repo_model_entry(
"llama-4-scout-17b-16e-instruct",
CoreModelId.llama4_scout_17b_16e_instruct.value,
),
] + SAFETY_MODELS_ENTRIES

View file

@ -47,7 +47,10 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import DatabricksImplConfig from .config import DatabricksImplConfig
model_entries = [ SAFETY_MODELS_ENTRIES = []
# https://docs.databricks.com/aws/en/machine-learning/model-serving/foundation-model-overview
MODEL_ENTRIES = [
build_hf_repo_model_entry( build_hf_repo_model_entry(
"databricks-meta-llama-3-1-70b-instruct", "databricks-meta-llama-3-1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value, CoreModelId.llama3_1_70b_instruct.value,
@ -56,7 +59,7 @@ model_entries = [
"databricks-meta-llama-3-1-405b-instruct", "databricks-meta-llama-3-1-405b-instruct",
CoreModelId.llama3_1_405b_instruct.value, CoreModelId.llama3_1_405b_instruct.value,
), ),
] ] + SAFETY_MODELS_ENTRIES
class DatabricksInferenceAdapter( class DatabricksInferenceAdapter(
@ -66,7 +69,7 @@ class DatabricksInferenceAdapter(
OpenAICompletionToLlamaStackMixin, OpenAICompletionToLlamaStackMixin,
): ):
def __init__(self, config: DatabricksImplConfig) -> None: def __init__(self, config: DatabricksImplConfig) -> None:
ModelRegistryHelper.__init__(self, model_entries=model_entries) ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
self.config = config self.config = config
async def initialize(self) -> None: async def initialize(self) -> None:

View file

@ -11,6 +11,17 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry, build_hf_repo_model_entry,
) )
SAFETY_MODELS_ENTRIES = [
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
]
MODEL_ENTRIES = [ MODEL_ENTRIES = [
build_hf_repo_model_entry( build_hf_repo_model_entry(
"accounts/fireworks/models/llama-v3p1-8b-instruct", "accounts/fireworks/models/llama-v3p1-8b-instruct",
@ -40,14 +51,6 @@ MODEL_ENTRIES = [
"accounts/fireworks/models/llama-v3p3-70b-instruct", "accounts/fireworks/models/llama-v3p3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value, CoreModelId.llama3_3_70b_instruct.value,
), ),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-8b",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"accounts/fireworks/models/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
build_hf_repo_model_entry( build_hf_repo_model_entry(
"accounts/fireworks/models/llama4-scout-instruct-basic", "accounts/fireworks/models/llama4-scout-instruct-basic",
CoreModelId.llama4_scout_17b_16e_instruct.value, CoreModelId.llama4_scout_17b_16e_instruct.value,
@ -64,4 +67,4 @@ MODEL_ENTRIES = [
"context_length": 8192, "context_length": 8192,
}, },
), ),
] ] + SAFETY_MODELS_ENTRIES

View file

@ -17,11 +17,16 @@ LLM_MODEL_IDS = [
"gemini/gemini-2.5-pro", "gemini/gemini-2.5-pro",
] ]
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [ MODEL_ENTRIES = (
ProviderModelEntry( [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
provider_model_id="gemini/text-embedding-004", + [
model_type=ModelType.embedding, ProviderModelEntry(
metadata={"embedding_dimension": 768, "context_length": 2048}, provider_model_id="gemini/text-embedding-004",
), model_type=ModelType.embedding,
] metadata={"embedding_dimension": 768, "context_length": 2048},
),
]
+ SAFETY_MODELS_ENTRIES
)

View file

@ -38,24 +38,18 @@ class GroqInferenceAdapter(LiteLLMOpenAIMixin):
provider_data_api_key_field="groq_api_key", provider_data_api_key_field="groq_api_key",
) )
self.config = config self.config = config
self._openai_client = None
async def initialize(self): async def initialize(self):
await super().initialize() await super().initialize()
async def shutdown(self): async def shutdown(self):
await super().shutdown() await super().shutdown()
if self._openai_client:
await self._openai_client.close()
self._openai_client = None
def _get_openai_client(self) -> AsyncOpenAI: def _get_openai_client(self) -> AsyncOpenAI:
if not self._openai_client: return AsyncOpenAI(
self._openai_client = AsyncOpenAI( base_url=f"{self.config.url}/openai/v1",
base_url=f"{self.config.url}/openai/v1", api_key=self.get_api_key(),
api_key=self.config.api_key, )
)
return self._openai_client
async def openai_chat_completion( async def openai_chat_completion(
self, self,
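The Groq adapter above stops caching an AsyncOpenAI client on the instance and instead builds one on demand, so the API key can be resolved at call time (from config or from per-request provider data) rather than being frozen at construction. A rough, self-contained sketch of that pattern follows; the get_api_key() helper is a hypothetical stand-in for the mixin's, and the snippet assumes the openai Python package.

from openai import AsyncOpenAI

class PerRequestClientAdapter:
    def __init__(self, base_url: str, config_api_key: str | None = None) -> None:
        self.base_url = base_url
        self.config_api_key = config_api_key
        self.request_api_key: str | None = None  # e.g. filled in from request headers

    def get_api_key(self) -> str:
        key = self.request_api_key or self.config_api_key
        if not key:
            raise ValueError("no API key available from config or request data")
        return key

    def _get_openai_client(self) -> AsyncOpenAI:
        # No caching: each call picks up whichever key applies to the current request.
        return AsyncOpenAI(base_url=f"{self.base_url}/openai/v1", api_key=self.get_api_key())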

View file

@ -10,6 +10,8 @@ from llama_stack.providers.utils.inference.model_registry import (
build_model_entry, build_model_entry,
) )
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = [ MODEL_ENTRIES = [
build_hf_repo_model_entry( build_hf_repo_model_entry(
"groq/llama3-8b-8192", "groq/llama3-8b-8192",
@ -51,4 +53,4 @@ MODEL_ENTRIES = [
"groq/meta-llama/llama-4-maverick-17b-128e-instruct", "groq/meta-llama/llama-4-maverick-17b-128e-instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value, CoreModelId.llama4_maverick_17b_128e_instruct.value,
), ),
] ] + SAFETY_MODELS_ENTRIES

View file

@ -11,6 +11,9 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry, build_hf_repo_model_entry,
) )
SAFETY_MODELS_ENTRIES = []
# https://docs.nvidia.com/nim/large-language-models/latest/supported-llm-agnostic-architectures.html
MODEL_ENTRIES = [ MODEL_ENTRIES = [
build_hf_repo_model_entry( build_hf_repo_model_entry(
"meta/llama3-8b-instruct", "meta/llama3-8b-instruct",
@ -99,4 +102,4 @@ MODEL_ENTRIES = [
), ),
# TODO(mf): how do we handle Nemotron models? # TODO(mf): how do we handle Nemotron models?
# "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct", # "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct",
] ] + SAFETY_MODELS_ENTRIES

View file

@ -48,16 +48,20 @@ EMBEDDING_MODEL_IDS: dict[str, EmbeddingModelInfo] = {
"text-embedding-3-small": EmbeddingModelInfo(1536, 8192), "text-embedding-3-small": EmbeddingModelInfo(1536, 8192),
"text-embedding-3-large": EmbeddingModelInfo(3072, 8192), "text-embedding-3-large": EmbeddingModelInfo(3072, 8192),
} }
SAFETY_MODELS_ENTRIES = []
MODEL_ENTRIES = (
MODEL_ENTRIES = [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS] + [ [ProviderModelEntry(provider_model_id=m) for m in LLM_MODEL_IDS]
ProviderModelEntry( + [
provider_model_id=model_id, ProviderModelEntry(
model_type=ModelType.embedding, provider_model_id=model_id,
metadata={ model_type=ModelType.embedding,
"embedding_dimension": model_info.embedding_dimension, metadata={
"context_length": model_info.context_length, "embedding_dimension": model_info.embedding_dimension,
}, "context_length": model_info.context_length,
) },
for model_id, model_info in EMBEDDING_MODEL_IDS.items() )
] for model_id, model_info in EMBEDDING_MODEL_IDS.items()
]
+ SAFETY_MODELS_ENTRIES
)

View file

@ -59,9 +59,6 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
# if we do not set this, users will be exposed to the # if we do not set this, users will be exposed to the
# litellm specific model names, an abstraction leak. # litellm specific model names, an abstraction leak.
self.is_openai_compat = True self.is_openai_compat = True
self._openai_client = AsyncOpenAI(
api_key=self.config.api_key,
)
async def initialize(self) -> None: async def initialize(self) -> None:
await super().initialize() await super().initialize()
@ -69,6 +66,11 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
async def shutdown(self) -> None: async def shutdown(self) -> None:
await super().shutdown() await super().shutdown()
def _get_openai_client(self) -> AsyncOpenAI:
return AsyncOpenAI(
api_key=self.get_api_key(),
)
async def openai_completion( async def openai_completion(
self, self,
model: str, model: str,
@ -120,7 +122,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
user=user, user=user,
suffix=suffix, suffix=suffix,
) )
return await self._openai_client.completions.create(**params) return await self._get_openai_client().completions.create(**params)
async def openai_chat_completion( async def openai_chat_completion(
self, self,
@ -176,7 +178,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
top_p=top_p, top_p=top_p,
user=user, user=user,
) )
return await self._openai_client.chat.completions.create(**params) return await self._get_openai_client().chat.completions.create(**params)
async def openai_embeddings( async def openai_embeddings(
self, self,
@ -204,7 +206,7 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin):
params["user"] = user params["user"] = user
# Call OpenAI embeddings API # Call OpenAI embeddings API
response = await self._openai_client.embeddings.create(**params) response = await self._get_openai_client().embeddings.create(**params)
data = [] data = []
for i, embedding_data in enumerate(response.data): for i, embedding_data in enumerate(response.data):

View file

@ -11,7 +11,7 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.inference import OpenAIEmbeddingsResponse from llama_stack.apis.inference import OpenAIEmbeddingsResponse
# from llama_stack.providers.datatypes import ModelsProtocolPrivate # from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper, build_hf_repo_model_entry
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin, OpenAIChatCompletionToLlamaStackMixin,
OpenAICompletionToLlamaStackMixin, OpenAICompletionToLlamaStackMixin,
@ -25,6 +25,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import RunpodImplConfig from .config import RunpodImplConfig
# https://docs.runpod.io/serverless/vllm/overview#compatible-models
# https://github.com/runpod-workers/worker-vllm/blob/main/README.md#compatible-model-architectures
RUNPOD_SUPPORTED_MODELS = { RUNPOD_SUPPORTED_MODELS = {
"Llama3.1-8B": "meta-llama/Llama-3.1-8B", "Llama3.1-8B": "meta-llama/Llama-3.1-8B",
"Llama3.1-70B": "meta-llama/Llama-3.1-70B", "Llama3.1-70B": "meta-llama/Llama-3.1-70B",
@ -40,6 +42,14 @@ RUNPOD_SUPPORTED_MODELS = {
"Llama3.2-3B": "meta-llama/Llama-3.2-3B", "Llama3.2-3B": "meta-llama/Llama-3.2-3B",
} }
SAFETY_MODELS_ENTRIES = []
# Create MODEL_ENTRIES from RUNPOD_SUPPORTED_MODELS for compatibility with starter template
MODEL_ENTRIES = [
build_hf_repo_model_entry(provider_model_id, model_descriptor)
for provider_model_id, model_descriptor in RUNPOD_SUPPORTED_MODELS.items()
] + SAFETY_MODELS_ENTRIES
class RunpodInferenceAdapter( class RunpodInferenceAdapter(
ModelRegistryHelper, ModelRegistryHelper,

View file

@ -9,6 +9,14 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry, build_hf_repo_model_entry,
) )
SAFETY_MODELS_ENTRIES = [
build_hf_repo_model_entry(
"sambanova/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
]
MODEL_ENTRIES = [ MODEL_ENTRIES = [
build_hf_repo_model_entry( build_hf_repo_model_entry(
"sambanova/Meta-Llama-3.1-8B-Instruct", "sambanova/Meta-Llama-3.1-8B-Instruct",
@ -46,8 +54,4 @@ MODEL_ENTRIES = [
"sambanova/Llama-4-Maverick-17B-128E-Instruct", "sambanova/Llama-4-Maverick-17B-128E-Instruct",
CoreModelId.llama4_maverick_17b_128e_instruct.value, CoreModelId.llama4_maverick_17b_128e_instruct.value,
), ),
build_hf_repo_model_entry( ] + SAFETY_MODELS_ENTRIES
"sambanova/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
]

View file

@ -7,6 +7,7 @@
import json import json
from collections.abc import Iterable from collections.abc import Iterable
import requests
from openai.types.chat import ( from openai.types.chat import (
ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage,
) )
@ -56,6 +57,7 @@ from llama_stack.apis.inference import (
ToolResponseMessage, ToolResponseMessage,
UserMessage, UserMessage,
) )
from llama_stack.apis.models import Model
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import BuiltinTool from llama_stack.models.llama.datatypes import BuiltinTool
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
@ -176,10 +178,11 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: SambaNovaImplConfig): def __init__(self, config: SambaNovaImplConfig):
self.config = config self.config = config
self.environment_available_models = []
LiteLLMOpenAIMixin.__init__( LiteLLMOpenAIMixin.__init__(
self, self,
model_entries=MODEL_ENTRIES, model_entries=MODEL_ENTRIES,
api_key_from_config=self.config.api_key, api_key_from_config=self.config.api_key.get_secret_value() if self.config.api_key else None,
provider_data_api_key_field="sambanova_api_key", provider_data_api_key_field="sambanova_api_key",
) )
@ -246,6 +249,22 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
**get_sampling_options(request.sampling_params), **get_sampling_options(request.sampling_params),
} }
async def register_model(self, model: Model) -> Model:
model_id = self.get_provider_model_id(model.provider_resource_id)
list_models_url = self.config.url + "/models"
if len(self.environment_available_models) == 0:
try:
response = requests.get(list_models_url)
response.raise_for_status()
except requests.exceptions.RequestException as e:
raise RuntimeError(f"Request to {list_models_url} failed") from e
self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
if model_id.split("sambanova/")[-1] not in self.environment_available_models:
logger.warning(f"Model {model_id} not available in {list_models_url}")
return model
async def initialize(self): async def initialize(self):
await super().initialize() await super().initialize()

View file

@ -11,6 +11,16 @@ from llama_stack.providers.utils.inference.model_registry import (
build_hf_repo_model_entry, build_hf_repo_model_entry,
) )
SAFETY_MODELS_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
CoreModelId.llama_guard_3_11b_vision.value,
),
]
MODEL_ENTRIES = [ MODEL_ENTRIES = [
build_hf_repo_model_entry( build_hf_repo_model_entry(
"meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo", "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
@ -40,14 +50,6 @@ MODEL_ENTRIES = [
"meta-llama/Llama-3.3-70B-Instruct-Turbo", "meta-llama/Llama-3.3-70B-Instruct-Turbo",
CoreModelId.llama3_3_70b_instruct.value, CoreModelId.llama3_3_70b_instruct.value,
), ),
build_hf_repo_model_entry(
"meta-llama/Meta-Llama-Guard-3-8B",
CoreModelId.llama_guard_3_8b.value,
),
build_hf_repo_model_entry(
"meta-llama/Llama-Guard-3-11B-Vision-Turbo",
CoreModelId.llama_guard_3_11b_vision.value,
),
ProviderModelEntry( ProviderModelEntry(
provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval", provider_model_id="togethercomputer/m2-bert-80M-8k-retrieval",
model_type=ModelType.embedding, model_type=ModelType.embedding,
@ -78,4 +80,4 @@ MODEL_ENTRIES = [
"together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8", "together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
], ],
), ),
] ] + SAFETY_MODELS_ENTRIES

View file

@ -68,19 +68,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
def __init__(self, config: TogetherImplConfig) -> None: def __init__(self, config: TogetherImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES) ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
self.config = config self.config = config
self._client = None
self._openai_client = None
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
async def shutdown(self) -> None: async def shutdown(self) -> None:
if self._client: pass
# Together client has no close method, so just set to None
self._client = None
if self._openai_client:
await self._openai_client.close()
self._openai_client = None
async def completion( async def completion(
self, self,
@ -108,29 +101,25 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
return await self._nonstream_completion(request) return await self._nonstream_completion(request)
def _get_client(self) -> AsyncTogether: def _get_client(self) -> AsyncTogether:
if not self._client: together_api_key = None
together_api_key = None config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None if config_api_key:
if config_api_key: together_api_key = config_api_key
together_api_key = config_api_key else:
else: provider_data = self.get_request_provider_data()
provider_data = self.get_request_provider_data() if provider_data is None or not provider_data.together_api_key:
if provider_data is None or not provider_data.together_api_key: raise ValueError(
raise ValueError( 'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}'
'Pass Together API Key in the header X-LlamaStack-Provider-Data as { "together_api_key": <your api key>}' )
) together_api_key = provider_data.together_api_key
together_api_key = provider_data.together_api_key return AsyncTogether(api_key=together_api_key)
self._client = AsyncTogether(api_key=together_api_key)
return self._client
def _get_openai_client(self) -> AsyncOpenAI: def _get_openai_client(self) -> AsyncOpenAI:
if not self._openai_client: together_client = self._get_client().client
together_client = self._get_client().client return AsyncOpenAI(
self._openai_client = AsyncOpenAI( base_url=together_client.base_url,
base_url=together_client.base_url, api_key=together_client.api_key,
api_key=together_client.api_key, )
)
return self._openai_client
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse: async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request) params = await self._get_params(request)

View file

@ -33,6 +33,7 @@ CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"
class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData): class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProviderData):
def __init__(self, config: SambaNovaSafetyConfig) -> None: def __init__(self, config: SambaNovaSafetyConfig) -> None:
self.config = config self.config = config
self.environment_available_models = []
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass
@ -54,18 +55,18 @@ class SambaNovaSafetyAdapter(Safety, ShieldsProtocolPrivate, NeedsRequestProvide
async def register_shield(self, shield: Shield) -> None: async def register_shield(self, shield: Shield) -> None:
list_models_url = self.config.url + "/models" list_models_url = self.config.url + "/models"
try: if len(self.environment_available_models) == 0:
response = requests.get(list_models_url) try:
response.raise_for_status() response = requests.get(list_models_url)
except requests.exceptions.RequestException as e: response.raise_for_status()
raise RuntimeError(f"Request to {list_models_url} failed") from e except requests.exceptions.RequestException as e:
available_models = [model.get("id") for model in response.json().get("data", {})] raise RuntimeError(f"Request to {list_models_url} failed") from e
self.environment_available_models = [model.get("id") for model in response.json().get("data", {})]
if ( if (
len(available_models) == 0 "guard" not in shield.provider_resource_id.lower()
or "guard" not in shield.provider_resource_id.lower() or shield.provider_resource_id.split("sambanova/")[-1] not in self.environment_available_models
or shield.provider_resource_id.split("sambanova/")[-1] not in available_models
): ):
raise ValueError(f"Shield {shield.provider_resource_id} not found in SambaNova") logger.warning(f"Shield {shield.provider_resource_id} not available in {list_models_url}")
async def run_shield( async def run_shield(
self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None self, shield_id: str, messages: list[Message], params: dict[str, Any] | None = None

View file

@ -61,6 +61,11 @@ class MilvusIndex(EmbeddingIndex):
self.consistency_level = consistency_level self.consistency_level = consistency_level
self.kvstore = kvstore self.kvstore = kvstore
async def initialize(self):
# MilvusIndex does not require explicit initialization
# TODO: could move collection creation into initialization but it is not really necessary
pass
async def delete(self): async def delete(self):
if await asyncio.to_thread(self.client.has_collection, self.collection_name): if await asyncio.to_thread(self.client.has_collection, self.collection_name):
await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name) await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
@ -199,6 +204,9 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolP
if vector_db_id in self.cache: if vector_db_id in self.cache:
return self.cache[vector_db_id] return self.cache[vector_db_id]
if self.vector_db_store is None:
raise ValueError(f"Vector DB {vector_db_id} not found")
vector_db = await self.vector_db_store.get_vector_db(vector_db_id) vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
if not vector_db: if not vector_db:
raise ValueError(f"Vector DB {vector_db_id} not found") raise ValueError(f"Vector DB {vector_db_id} not found")

View file

@ -44,6 +44,7 @@ def build_hf_repo_model_entry(
] ]
if additional_aliases: if additional_aliases:
aliases.extend(additional_aliases) aliases.extend(additional_aliases)
aliases = [alias for alias in aliases if alias is not None]
return ProviderModelEntry( return ProviderModelEntry(
provider_model_id=provider_model_id, provider_model_id=provider_model_id,
aliases=aliases, aliases=aliases,
@ -82,35 +83,35 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
def get_llama_model(self, provider_model_id: str) -> str | None: def get_llama_model(self, provider_model_id: str) -> str | None:
return self.provider_id_to_llama_model_map.get(provider_model_id, None) return self.provider_id_to_llama_model_map.get(provider_model_id, None)
async def query_available_models(self) -> list[str]: async def check_model_availability(self, model: str) -> bool:
""" """
Return a list of available models. Check if a specific model is available from the provider (non-static check).
This is for subclassing purposes, so providers can lookup a list of This is for subclassing purposes, so providers can check if a specific
of currently available models. model is currently available for use through dynamic means (e.g., API calls).
This is combined with the statically configured model entries in This method should NOT check statically configured model entries in
`self.alias_to_provider_id_map` to determine which models are `self.alias_to_provider_id_map` - that is handled separately in register_model.
available for registration.
Default implementation returns no models. Default implementation returns False (no dynamic models available).
:return: A list of model identifiers (provider_model_ids). :param model: The model identifier to check.
:return: True if the model is available dynamically, False otherwise.
""" """
return [] return False
async def register_model(self, model: Model) -> Model: async def register_model(self, model: Model) -> Model:
# Check if model is supported in static configuration # Check if model is supported in static configuration
supported_model_id = self.get_provider_model_id(model.provider_resource_id) supported_model_id = self.get_provider_model_id(model.provider_resource_id)
# If not found in static config, check if it's available from provider # If not found in static config, check if it's available dynamically from provider
if not supported_model_id: if not supported_model_id:
available_models = await self.query_available_models() if await self.check_model_availability(model.provider_resource_id):
if model.provider_resource_id in available_models:
supported_model_id = model.provider_resource_id supported_model_id = model.provider_resource_id
else: else:
# Combine static and dynamic models for error message # note: we cannot provide a complete list of supported models without
all_supported_models = list(self.alias_to_provider_id_map.keys()) + available_models # getting a complete list from the provider, so we return "..."
all_supported_models = [*self.alias_to_provider_id_map.keys(), "..."]
raise UnsupportedModelError(model.provider_resource_id, all_supported_models) raise UnsupportedModelError(model.provider_resource_id, all_supported_models)
provider_resource_id = self.get_provider_model_id(model.model_id) provider_resource_id = self.get_provider_model_id(model.model_id)
@ -118,7 +119,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id provider_resource_id = model.provider_resource_id
if provider_resource_id: if provider_resource_id:
if provider_resource_id != supported_model_id: # be idemopotent, only reject differences if provider_resource_id != supported_model_id: # be idempotent, only reject differences
raise ValueError( raise ValueError(
f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first." f"Model id '{model.model_id}' is already registered. Please use a different id or unregister it first."
) )
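With register_model now falling back to check_model_availability(), a provider can accept models that are missing from its static entries as long as it can confirm them dynamically. The sketch below shows one hypothetical way to implement that hook by querying a model-listing endpoint; the /api/tags path and the {"models": [{"name": ...}]} response shape are assumptions modelled on an Ollama-style server, not something defined in this diff.

import asyncio
import requests

class DynamicModelChecker:
    """Stand-alone sketch; a real adapter would override ModelRegistryHelper.check_model_availability."""

    def __init__(self, url: str) -> None:
        self.url = url

    async def check_model_availability(self, model: str) -> bool:
        def _list_models() -> list[str]:
            response = requests.get(f"{self.url}/api/tags", timeout=5)
            response.raise_for_status()
            return [m.get("name", "") for m in response.json().get("models", [])]

        try:
            available = await asyncio.to_thread(_list_models)
        except requests.exceptions.RequestException:
            return False  # an unreachable provider simply reports the model as unavailable
        return model in available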

View file

@ -39,22 +39,10 @@ SQL_OPTIMIZED_POLICY = [
class SqlRecord(ProtectedResource): class SqlRecord(ProtectedResource):
"""Simple ProtectedResource implementation for SQL records.""" def __init__(self, record_id: str, table_name: str, owner: User):
def __init__(self, record_id: str, table_name: str, access_attributes: dict[str, list[str]] | None = None):
self.type = f"sql_record::{table_name}" self.type = f"sql_record::{table_name}"
self.identifier = record_id self.identifier = record_id
self.owner = owner
if access_attributes:
self.owner = User(
principal="system",
attributes=access_attributes,
)
else:
self.owner = User(
principal="system_public",
attributes=None,
)
class AuthorizedSqlStore: class AuthorizedSqlStore:
@ -101,22 +89,27 @@ class AuthorizedSqlStore:
async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None: async def create_table(self, table: str, schema: Mapping[str, ColumnType | ColumnDefinition]) -> None:
"""Create a table with built-in access control support.""" """Create a table with built-in access control support."""
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
enhanced_schema = dict(schema) enhanced_schema = dict(schema)
if "access_attributes" not in enhanced_schema: if "access_attributes" not in enhanced_schema:
enhanced_schema["access_attributes"] = ColumnType.JSON enhanced_schema["access_attributes"] = ColumnType.JSON
if "owner_principal" not in enhanced_schema:
enhanced_schema["owner_principal"] = ColumnType.STRING
await self.sql_store.create_table(table, enhanced_schema) await self.sql_store.create_table(table, enhanced_schema)
await self.sql_store.add_column_if_not_exists(table, "access_attributes", ColumnType.JSON)
await self.sql_store.add_column_if_not_exists(table, "owner_principal", ColumnType.STRING)
async def insert(self, table: str, data: Mapping[str, Any]) -> None: async def insert(self, table: str, data: Mapping[str, Any]) -> None:
"""Insert a row with automatic access control attribute capture.""" """Insert a row with automatic access control attribute capture."""
enhanced_data = dict(data) enhanced_data = dict(data)
current_user = get_authenticated_user() current_user = get_authenticated_user()
if current_user and current_user.attributes: if current_user:
enhanced_data["owner_principal"] = current_user.principal
enhanced_data["access_attributes"] = current_user.attributes enhanced_data["access_attributes"] = current_user.attributes
else: else:
enhanced_data["owner_principal"] = None
enhanced_data["access_attributes"] = None enhanced_data["access_attributes"] = None
await self.sql_store.insert(table, enhanced_data) await self.sql_store.insert(table, enhanced_data)
@ -146,9 +139,12 @@ class AuthorizedSqlStore:
for row in rows.data: for row in rows.data:
stored_access_attrs = row.get("access_attributes") stored_access_attrs = row.get("access_attributes")
stored_owner_principal = row.get("owner_principal") or ""
record_id = row.get("id", "unknown") record_id = row.get("id", "unknown")
sql_record = SqlRecord(str(record_id), table, stored_access_attrs) sql_record = SqlRecord(
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
)
if is_action_allowed(policy, Action.READ, sql_record, current_user): if is_action_allowed(policy, Action.READ, sql_record, current_user):
filtered_rows.append(row) filtered_rows.append(row)
@ -186,8 +182,10 @@ class AuthorizedSqlStore:
Only applies SQL filtering for the default policy to ensure correctness. Only applies SQL filtering for the default policy to ensure correctness.
For custom policies, uses conservative filtering to avoid blocking legitimate access. For custom policies, uses conservative filtering to avoid blocking legitimate access.
""" """
current_user = get_authenticated_user()
if not policy or policy == SQL_OPTIMIZED_POLICY: if not policy or policy == SQL_OPTIMIZED_POLICY:
return self._build_default_policy_where_clause() return self._build_default_policy_where_clause(current_user)
else: else:
return self._build_conservative_where_clause() return self._build_conservative_where_clause()
@ -227,29 +225,27 @@ class AuthorizedSqlStore:
def _get_public_access_conditions(self) -> list[str]: def _get_public_access_conditions(self) -> list[str]:
"""Get the SQL conditions for public access.""" """Get the SQL conditions for public access."""
# Public records are records that have no owner_principal or access_attributes
conditions = ["owner_principal = ''"]
if self.database_type == SqlStoreType.postgres: if self.database_type == SqlStoreType.postgres:
# Postgres stores JSON null as 'null' # Postgres stores JSON null as 'null'
return ["access_attributes::text = 'null'"] conditions.append("access_attributes::text = 'null'")
elif self.database_type == SqlStoreType.sqlite: elif self.database_type == SqlStoreType.sqlite:
return ["access_attributes = 'null'"] conditions.append("access_attributes = 'null'")
else: else:
raise ValueError(f"Unsupported database type: {self.database_type}") raise ValueError(f"Unsupported database type: {self.database_type}")
return conditions
def _build_default_policy_where_clause(self) -> str: def _build_default_policy_where_clause(self, current_user: User | None) -> str:
"""Build SQL WHERE clause for the default policy. """Build SQL WHERE clause for the default policy.
Default policy: permit all actions when user in owners [roles, teams, projects, namespaces] Default policy: permit all actions when user in owners [roles, teams, projects, namespaces]
This means user must match ALL attribute categories that exist in the resource. This means user must match ALL attribute categories that exist in the resource.
""" """
current_user = get_authenticated_user()
base_conditions = self._get_public_access_conditions() base_conditions = self._get_public_access_conditions()
if not current_user or not current_user.attributes: user_attr_conditions = []
# Only allow public records
return f"({' OR '.join(base_conditions)})"
else:
user_attr_conditions = []
if current_user and current_user.attributes:
for attr_key, user_values in current_user.attributes.items(): for attr_key, user_values in current_user.attributes.items():
if user_values: if user_values:
value_conditions = [] value_conditions = []
@ -269,7 +265,7 @@ class AuthorizedSqlStore:
all_requirements_met = f"({' AND '.join(user_attr_conditions)})" all_requirements_met = f"({' AND '.join(user_attr_conditions)})"
base_conditions.append(all_requirements_met) base_conditions.append(all_requirements_met)
return f"({' OR '.join(base_conditions)})" return f"({' OR '.join(base_conditions)})"
def _build_conservative_where_clause(self) -> str: def _build_conservative_where_clause(self) -> str:
"""Conservative SQL filtering for custom policies. """Conservative SQL filtering for custom policies.

View file

@ -244,35 +244,41 @@ class SqlAlchemySqlStoreImpl(SqlStore):
engine = create_async_engine(self.config.engine_str) engine = create_async_engine(self.config.engine_str)
try: try:
inspector = inspect(engine)
table_names = inspector.get_table_names()
if table not in table_names:
return
existing_columns = inspector.get_columns(table)
column_names = [col["name"] for col in existing_columns]
if column_name in column_names:
return
sqlalchemy_type = TYPE_MAPPING.get(column_type)
if not sqlalchemy_type:
raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
# Create the ALTER TABLE statement
# Note: We need to get the dialect-specific type name
dialect = engine.dialect
type_impl = sqlalchemy_type()
compiled_type = type_impl.compile(dialect=dialect)
nullable_clause = "" if nullable else " NOT NULL"
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
async with engine.begin() as conn: async with engine.begin() as conn:
def check_column_exists(sync_conn):
inspector = inspect(sync_conn)
table_names = inspector.get_table_names()
if table not in table_names:
return False, False # table doesn't exist, column doesn't exist
existing_columns = inspector.get_columns(table)
column_names = [col["name"] for col in existing_columns]
return True, column_name in column_names # table exists, column exists or not
table_exists, column_exists = await conn.run_sync(check_column_exists)
if not table_exists or column_exists:
return
sqlalchemy_type = TYPE_MAPPING.get(column_type)
if not sqlalchemy_type:
raise ValueError(f"Unsupported column type '{column_type}' for column '{column_name}'.")
# Create the ALTER TABLE statement
# Note: We need to get the dialect-specific type name
dialect = engine.dialect
type_impl = sqlalchemy_type()
compiled_type = type_impl.compile(dialect=dialect)
nullable_clause = "" if nullable else " NOT NULL"
add_column_sql = text(f"ALTER TABLE {table} ADD COLUMN {column_name} {compiled_type}{nullable_clause}")
await conn.execute(add_column_sql) await conn.execute(add_column_sql)
except Exception: except Exception as e:
# If any error occurs during migration, log it but don't fail # If any error occurs during migration, log it but don't fail
# The table creation will handle adding the column # The table creation will handle adding the column
logger.error(f"Error adding column {column_name} to table {table}: {e}")
pass pass
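The migration above moves the schema inspection inside the connection because SQLAlchemy's inspect() cannot be used on an async engine directly; it has to run against a synchronous connection via conn.run_sync(). A condensed sketch of that pattern is below; the table and column names are made up, and the sqlite+aiosqlite URL is only an example.

import asyncio
from sqlalchemy import inspect, text
from sqlalchemy.ext.asyncio import create_async_engine

async def add_column_if_missing(engine_url: str, table: str, column_name: str, column_ddl: str) -> None:
    engine = create_async_engine(engine_url)
    try:
        async with engine.begin() as conn:
            def _existing_columns(sync_conn):
                inspector = inspect(sync_conn)
                if table not in inspector.get_table_names():
                    return None  # table does not exist yet
                return [col["name"] for col in inspector.get_columns(table)]

            columns = await conn.run_sync(_existing_columns)
            if columns is None or column_name in columns:
                return
            await conn.execute(text(f"ALTER TABLE {table} ADD COLUMN {column_name} {column_ddl}"))
    finally:
        await engine.dispose()

asyncio.run(add_column_if_missing("sqlite+aiosqlite:///:memory:", "items", "notes", "TEXT"))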

View file

@ -9,14 +9,12 @@ import inspect
import json import json
from collections.abc import AsyncGenerator, Callable from collections.abc import AsyncGenerator, Callable
from functools import wraps from functools import wraps
from typing import Any, TypeVar from typing import Any
from pydantic import BaseModel from pydantic import BaseModel
from llama_stack.models.llama.datatypes import Primitive from llama_stack.models.llama.datatypes import Primitive
T = TypeVar("T")
def serialize_value(value: Any) -> Primitive: def serialize_value(value: Any) -> Primitive:
return str(_prepare_for_json(value)) return str(_prepare_for_json(value))
@ -44,7 +42,7 @@ def _prepare_for_json(value: Any) -> str:
return str(value) return str(value)
def trace_protocol(cls: type[T]) -> type[T]: def trace_protocol[T](cls: type[T]) -> type[T]:
""" """
A class decorator that automatically traces all methods in a protocol/base class A class decorator that automatically traces all methods in a protocol/base class
and its inheriting classes. and its inheriting classes.
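The only signature change in this file is the move from a module-level TypeVar to PEP 695 type-parameter syntax, which requires Python 3.12 or newer. The two spellings below are equivalent; identity_decorator is a hypothetical example rather than the real trace_protocol.

from typing import TypeVar

T = TypeVar("T")

def identity_decorator_legacy(cls: type[T]) -> type[T]:  # pre-3.12 spelling
    return cls

def identity_decorator[U](cls: type[U]) -> type[U]:  # PEP 695 spelling (3.12+)
    return cls

@identity_decorator
class Example:
    pass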

View file

@ -39,6 +39,9 @@ providers:
provider_type: inline::sqlite-vec provider_type: inline::sqlite-vec
config: config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/sqlite_vec.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/sqlite_vec.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/open-benchmark}/sqlite_vec_registry.db
- provider_id: ${env.ENABLE_CHROMADB:+chromadb} - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
provider_type: remote::chromadb provider_type: remote::chromadb
config: config:

View file

@ -144,6 +144,9 @@ providers:
provider_type: inline::sqlite-vec provider_type: inline::sqlite-vec
config: config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
kvstore:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
- provider_id: ${env.ENABLE_MILVUS:=__disabled__} - provider_id: ${env.ENABLE_MILVUS:=__disabled__}
provider_type: inline::milvus provider_type: inline::milvus
config: config:
@ -256,11 +259,46 @@ inference_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/inference_store.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/inference_store.db
models: models:
- metadata: {}
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama3.1-8b
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
provider_model_id: llama3.1-8b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
provider_model_id: llama3.1-8b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama-3.3-70b
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
provider_model_id: llama-3.3-70b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/meta-llama/Llama-3.3-70B-Instruct
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
provider_model_id: llama-3.3-70b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama-4-scout-17b-16e-instruct
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
provider_model_id: llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/meta-llama/Llama-4-Scout-17B-16E-Instruct
provider_id: ${env.ENABLE_CEREBRAS:=__disabled__}
provider_model_id: llama-4-scout-17b-16e-instruct
model_type: llm
- metadata: {} - metadata: {}
model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.OLLAMA_INFERENCE_MODEL:=__disabled__} model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.OLLAMA_INFERENCE_MODEL:=__disabled__}
provider_id: ${env.ENABLE_OLLAMA:=__disabled__} provider_id: ${env.ENABLE_OLLAMA:=__disabled__}
provider_model_id: ${env.OLLAMA_INFERENCE_MODEL:=__disabled__} provider_model_id: ${env.OLLAMA_INFERENCE_MODEL:=__disabled__}
model_type: llm model_type: llm
- metadata: {}
model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__}
provider_id: ${env.ENABLE_OLLAMA:=__disabled__}
provider_model_id: ${env.SAFETY_MODEL:=__disabled__}
model_type: llm
- metadata: - metadata:
embedding_dimension: ${env.OLLAMA_EMBEDDING_DIMENSION:=384} embedding_dimension: ${env.OLLAMA_EMBEDDING_DIMENSION:=384}
model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.OLLAMA_EMBEDDING_MODEL:=__disabled__} model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}
@ -342,26 +380,6 @@ models:
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct
model_type: llm model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-8b
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-8B
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-11b-vision
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata: {} - metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama4-scout-instruct-basic model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama4-scout-instruct-basic
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
@ -389,6 +407,26 @@ models:
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: nomic-ai/nomic-embed-text-v1.5 provider_model_id: nomic-ai/nomic-embed-text-v1.5
model_type: embedding model_type: embedding
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-8b
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-8B
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-8b
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-11b-vision
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
provider_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision
model_type: llm
- metadata: {} - metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo
provider_id: ${env.ENABLE_TOGETHER:=__disabled__} provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
@ -459,26 +497,6 @@ models:
provider_id: ${env.ENABLE_TOGETHER:=__disabled__} provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-3.3-70B-Instruct-Turbo provider_model_id: meta-llama/Llama-3.3-70B-Instruct-Turbo
model_type: llm model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Meta-Llama-Guard-3-8B
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Meta-Llama-Guard-3-8B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-8B
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Meta-Llama-Guard-3-8B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision-Turbo
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
model_type: llm
- metadata: - metadata:
embedding_dimension: 768 embedding_dimension: 768
context_length: 8192 context_length: 8192
@ -523,6 +541,264 @@ models:
provider_id: ${env.ENABLE_TOGETHER:=__disabled__} provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 provider_model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8
model_type: llm model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-8B
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-Guard-3-8B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-8B
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-Guard-3-8B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision-Turbo
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision
provider_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta.llama3-1-8b-instruct-v1:0
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_model_id: meta.llama3-1-8b-instruct-v1:0
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_model_id: meta.llama3-1-8b-instruct-v1:0
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta.llama3-1-70b-instruct-v1:0
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_model_id: meta.llama3-1-70b-instruct-v1:0
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_model_id: meta.llama3-1-70b-instruct-v1:0
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta.llama3-1-405b-instruct-v1:0
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_model_id: meta.llama3-1-405b-instruct-v1:0
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8
provider_id: ${env.ENABLE_BEDROCK:=__disabled__}
provider_model_id: meta.llama3-1-405b-instruct-v1:0
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/databricks-meta-llama-3-1-70b-instruct
provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
provider_model_id: databricks-meta-llama-3-1-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct
provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
provider_model_id: databricks-meta-llama-3-1-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/databricks-meta-llama-3-1-405b-instruct
provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
provider_model_id: databricks-meta-llama-3-1-405b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8
provider_id: ${env.ENABLE_DATABRICKS:=__disabled__}
provider_model_id: databricks-meta-llama-3-1-405b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama3-8b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama3-8b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3-8B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama3-8b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama3-70b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama3-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3-70B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama3-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.1-8b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.1-8b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.1-8b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.1-70b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.1-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.1-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.1-405b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.1-405b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.1-405b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-1b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-1b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-1B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-1b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-3b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-3b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-3B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-3b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-11b-vision-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-11B-Vision-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-90b-vision-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-90B-Vision-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.3-70b-instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.3-70b-instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.3-70B-Instruct
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: meta/llama-3.3-70b-instruct
model_type: llm
- metadata:
embedding_dimension: 2048
context_length: 8192
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/nvidia/llama-3.2-nv-embedqa-1b-v2
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2
model_type: embedding
- metadata:
embedding_dimension: 1024
context_length: 512
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/nvidia/nv-embedqa-e5-v5
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: nvidia/nv-embedqa-e5-v5
model_type: embedding
- metadata:
embedding_dimension: 4096
context_length: 512
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/nvidia/nv-embedqa-mistral-7b-v2
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: nvidia/nv-embedqa-mistral-7b-v2
model_type: embedding
- metadata:
embedding_dimension: 1024
context_length: 512
model_id: ${env.ENABLE_NVIDIA:=__disabled__}/snowflake/arctic-embed-l
provider_id: ${env.ENABLE_NVIDIA:=__disabled__}
provider_model_id: snowflake/arctic-embed-l
model_type: embedding
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-8B
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-8B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-70B
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-70B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B:bf16-mp8
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-405B:bf16-mp8
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-405B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B:bf16-mp16
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-405B:bf16-mp16
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-8B-Instruct
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-8B-Instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-70B-Instruct
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-70B-Instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B-Instruct:bf16-mp8
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-405B-Instruct:bf16-mp8
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B-Instruct
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-405B-Instruct
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B-Instruct:bf16-mp16
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.1-405B-Instruct:bf16-mp16
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.2-1B
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.2-1B
model_type: llm
- metadata: {}
model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.2-3B
provider_id: ${env.ENABLE_RUNPOD:=__disabled__}
provider_model_id: Llama3.2-3B
model_type: llm
- metadata: {} - metadata: {}
model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/gpt-4o model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/gpt-4o
provider_id: ${env.ENABLE_OPENAI:=__disabled__} provider_id: ${env.ENABLE_OPENAI:=__disabled__}
@ -894,7 +1170,25 @@ models:
model_id: all-MiniLM-L6-v2 model_id: all-MiniLM-L6-v2
provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers} provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers}
model_type: embedding model_type: embedding
shields: [] shields:
- shield_id: ${env.ENABLE_OLLAMA:=__disabled__}
provider_id: llama-guard
provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=llama-guard3:1b}
- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_id: llama-guard
provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-8b}
- shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}
provider_id: llama-guard
provider_shield_id: ${env.ENABLE_FIREWORKS:=__disabled__}/${env.SAFETY_MODEL:=accounts/fireworks/models/llama-guard-3-11b-vision}
- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_id: llama-guard
provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-8B}
- shield_id: ${env.ENABLE_TOGETHER:=__disabled__}
provider_id: llama-guard
provider_shield_id: ${env.ENABLE_TOGETHER:=__disabled__}/${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-11B-Vision-Turbo}
- shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}
provider_id: llama-guard
provider_shield_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/${env.SAFETY_MODEL:=sambanova/Meta-Llama-Guard-3-8B}
vector_dbs: [] vector_dbs: []
datasets: [] datasets: []
scoring_fns: [] scoring_fns: []

View file

@ -12,6 +12,7 @@ from llama_stack.distribution.datatypes import (
ModelInput, ModelInput,
Provider, Provider,
ProviderSpec, ProviderSpec,
ShieldInput,
ToolGroupInput, ToolGroupInput,
) )
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
@ -31,24 +32,75 @@ from llama_stack.providers.registry.inference import available_providers
from llama_stack.providers.remote.inference.anthropic.models import ( from llama_stack.providers.remote.inference.anthropic.models import (
MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES, MODEL_ENTRIES as ANTHROPIC_MODEL_ENTRIES,
) )
from llama_stack.providers.remote.inference.anthropic.models import (
SAFETY_MODELS_ENTRIES as ANTHROPIC_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.bedrock.models import (
MODEL_ENTRIES as BEDROCK_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.bedrock.models import (
SAFETY_MODELS_ENTRIES as BEDROCK_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.cerebras.models import (
MODEL_ENTRIES as CEREBRAS_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.cerebras.models import (
SAFETY_MODELS_ENTRIES as CEREBRAS_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.databricks.databricks import (
MODEL_ENTRIES as DATABRICKS_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.databricks.databricks import (
SAFETY_MODELS_ENTRIES as DATABRICKS_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.fireworks.models import ( from llama_stack.providers.remote.inference.fireworks.models import (
MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES, MODEL_ENTRIES as FIREWORKS_MODEL_ENTRIES,
) )
from llama_stack.providers.remote.inference.fireworks.models import (
SAFETY_MODELS_ENTRIES as FIREWORKS_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.gemini.models import ( from llama_stack.providers.remote.inference.gemini.models import (
MODEL_ENTRIES as GEMINI_MODEL_ENTRIES, MODEL_ENTRIES as GEMINI_MODEL_ENTRIES,
) )
from llama_stack.providers.remote.inference.gemini.models import (
SAFETY_MODELS_ENTRIES as GEMINI_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.groq.models import ( from llama_stack.providers.remote.inference.groq.models import (
MODEL_ENTRIES as GROQ_MODEL_ENTRIES, MODEL_ENTRIES as GROQ_MODEL_ENTRIES,
) )
from llama_stack.providers.remote.inference.groq.models import (
SAFETY_MODELS_ENTRIES as GROQ_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.nvidia.models import (
MODEL_ENTRIES as NVIDIA_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.nvidia.models import (
SAFETY_MODELS_ENTRIES as NVIDIA_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.openai.models import ( from llama_stack.providers.remote.inference.openai.models import (
MODEL_ENTRIES as OPENAI_MODEL_ENTRIES, MODEL_ENTRIES as OPENAI_MODEL_ENTRIES,
) )
from llama_stack.providers.remote.inference.openai.models import (
SAFETY_MODELS_ENTRIES as OPENAI_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.runpod.runpod import (
MODEL_ENTRIES as RUNPOD_MODEL_ENTRIES,
)
from llama_stack.providers.remote.inference.runpod.runpod import (
SAFETY_MODELS_ENTRIES as RUNPOD_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.sambanova.models import ( from llama_stack.providers.remote.inference.sambanova.models import (
MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES, MODEL_ENTRIES as SAMBANOVA_MODEL_ENTRIES,
) )
from llama_stack.providers.remote.inference.sambanova.models import (
SAFETY_MODELS_ENTRIES as SAMBANOVA_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.inference.together.models import ( from llama_stack.providers.remote.inference.together.models import (
MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES, MODEL_ENTRIES as TOGETHER_MODEL_ENTRIES,
) )
from llama_stack.providers.remote.inference.together.models import (
SAFETY_MODELS_ENTRIES as TOGETHER_SAFETY_MODELS_ENTRIES,
)
from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig from llama_stack.providers.remote.vector_io.chroma.config import ChromaVectorIOConfig
from llama_stack.providers.remote.vector_io.pgvector.config import ( from llama_stack.providers.remote.vector_io.pgvector.config import (
PGVectorVectorIOConfig, PGVectorVectorIOConfig,
@ -72,6 +124,11 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
"gemini": GEMINI_MODEL_ENTRIES, "gemini": GEMINI_MODEL_ENTRIES,
"groq": GROQ_MODEL_ENTRIES, "groq": GROQ_MODEL_ENTRIES,
"sambanova": SAMBANOVA_MODEL_ENTRIES, "sambanova": SAMBANOVA_MODEL_ENTRIES,
"cerebras": CEREBRAS_MODEL_ENTRIES,
"bedrock": BEDROCK_MODEL_ENTRIES,
"databricks": DATABRICKS_MODEL_ENTRIES,
"nvidia": NVIDIA_MODEL_ENTRIES,
"runpod": RUNPOD_MODEL_ENTRIES,
} }
# Special handling for providers with dynamic model entries # Special handling for providers with dynamic model entries
@ -81,6 +138,10 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=__disabled__}", provider_model_id="${env.OLLAMA_INFERENCE_MODEL:=__disabled__}",
model_type=ModelType.llm, model_type=ModelType.llm,
), ),
ProviderModelEntry(
provider_model_id="${env.SAFETY_MODEL:=__disabled__}",
model_type=ModelType.llm,
),
ProviderModelEntry( ProviderModelEntry(
provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}", provider_model_id="${env.OLLAMA_EMBEDDING_MODEL:=__disabled__}",
model_type=ModelType.embedding, model_type=ModelType.embedding,
@ -100,6 +161,35 @@ def _get_model_entries_for_provider(provider_type: str) -> list[ProviderModelEnt
return model_entries_map.get(provider_type, []) return model_entries_map.get(provider_type, [])
def _get_model_safety_entries_for_provider(provider_type: str) -> list[ProviderModelEntry]:
"""Get model entries for a specific provider type."""
safety_model_entries_map = {
"openai": OPENAI_SAFETY_MODELS_ENTRIES,
"fireworks": FIREWORKS_SAFETY_MODELS_ENTRIES,
"together": TOGETHER_SAFETY_MODELS_ENTRIES,
"anthropic": ANTHROPIC_SAFETY_MODELS_ENTRIES,
"gemini": GEMINI_SAFETY_MODELS_ENTRIES,
"groq": GROQ_SAFETY_MODELS_ENTRIES,
"sambanova": SAMBANOVA_SAFETY_MODELS_ENTRIES,
"cerebras": CEREBRAS_SAFETY_MODELS_ENTRIES,
"bedrock": BEDROCK_SAFETY_MODELS_ENTRIES,
"databricks": DATABRICKS_SAFETY_MODELS_ENTRIES,
"nvidia": NVIDIA_SAFETY_MODELS_ENTRIES,
"runpod": RUNPOD_SAFETY_MODELS_ENTRIES,
}
# Special handling for providers with dynamic model entries
if provider_type == "ollama":
return [
ProviderModelEntry(
provider_model_id="llama-guard3:1b",
model_type=ModelType.llm,
),
]
return safety_model_entries_map.get(provider_type, [])
def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]: def _get_config_for_provider(provider_spec: ProviderSpec) -> dict[str, Any]:
"""Get configuration for a provider using its adapter's config class.""" """Get configuration for a provider using its adapter's config class."""
config_class = instantiate_class_type(provider_spec.config_class) config_class = instantiate_class_type(provider_spec.config_class)
@ -155,6 +245,31 @@ def get_remote_inference_providers() -> tuple[list[Provider], dict[str, list[Pro
return inference_providers, available_models return inference_providers, available_models
# build a list of shields for all possible providers
def get_shields_for_providers(providers: list[Provider]) -> list[ShieldInput]:
shields = []
for provider in providers:
provider_type = provider.provider_type.split("::")[1]
safety_model_entries = _get_model_safety_entries_for_provider(provider_type)
if len(safety_model_entries) == 0:
continue
if provider.provider_id:
shield_id = provider.provider_id
else:
raise ValueError(f"Provider {provider.provider_type} has no provider_id")
for safety_model_entry in safety_model_entries:
print(f"provider.provider_id: {provider.provider_id}")
print(f"safety_model_entry.provider_model_id: {safety_model_entry.provider_model_id}")
shields.append(
ShieldInput(
provider_id="llama-guard",
shield_id=shield_id,
provider_shield_id=f"{provider.provider_id}/${{env.SAFETY_MODEL:={safety_model_entry.provider_model_id}}}",
)
)
return shields
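
For illustration only (not part of this commit), a hypothetical sketch of what the helper above would emit for a single Ollama provider; it reuses the Provider and ShieldInput datatypes imported at the top of this file, and the Provider field values here are assumptions for the sake of the example:

    ollama_provider = Provider(
        provider_id="${env.ENABLE_OLLAMA:=__disabled__}",
        provider_type="remote::ollama",
        config={},
    )
    shields = get_shields_for_providers([ollama_provider])
    # Expected: a single entry mirroring the shields block in the generated run.yaml:
    #   ShieldInput(
    #       shield_id="${env.ENABLE_OLLAMA:=__disabled__}",
    #       provider_id="llama-guard",
    #       provider_shield_id="${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=llama-guard3:1b}",
    #   )
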
def get_distribution_template() -> DistributionTemplate: def get_distribution_template() -> DistributionTemplate:
remote_inference_providers, available_models = get_remote_inference_providers() remote_inference_providers, available_models = get_remote_inference_providers()
@ -192,6 +307,8 @@ def get_distribution_template() -> DistributionTemplate:
), ),
] ]
shields = get_shields_for_providers(remote_inference_providers)
providers = { providers = {
"inference": ([p.provider_type for p in remote_inference_providers] + ["inline::sentence-transformers"]), "inference": ([p.provider_type for p in remote_inference_providers] + ["inline::sentence-transformers"]),
"vector_io": ([p.provider_type for p in vector_io_providers]), "vector_io": ([p.provider_type for p in vector_io_providers]),
@ -266,9 +383,7 @@ def get_distribution_template() -> DistributionTemplate:
default_models=default_models + [embedding_model], default_models=default_models + [embedding_model],
default_tool_groups=default_tool_groups, default_tool_groups=default_tool_groups,
# TODO: add a way to enable/disable shields on the fly # TODO: add a way to enable/disable shields on the fly
# default_shields=[ default_shields=shields,
# ShieldInput(provider_id="llama-guard", shield_id="${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-8B}")
# ],
), ),
}, },
run_config_env_vars={ run_config_env_vars={

View file

@ -0,0 +1,82 @@
"use client";
import { useEffect, useState } from "react";
import { useParams, useRouter } from "next/navigation";
import { useAuthClient } from "@/hooks/use-auth-client";
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
import { VectorStoreDetailView } from "@/components/vector-stores/vector-store-detail";
export default function VectorStoreDetailPage() {
const params = useParams();
const id = params.id as string;
const client = useAuthClient();
const router = useRouter();
const [store, setStore] = useState<VectorStore | null>(null);
const [files, setFiles] = useState<VectorStoreFile[]>([]);
const [isLoadingStore, setIsLoadingStore] = useState(true);
const [isLoadingFiles, setIsLoadingFiles] = useState(true);
const [errorStore, setErrorStore] = useState<Error | null>(null);
const [errorFiles, setErrorFiles] = useState<Error | null>(null);
useEffect(() => {
if (!id) {
setErrorStore(new Error("Vector Store ID is missing."));
setIsLoadingStore(false);
return;
}
const fetchStore = async () => {
setIsLoadingStore(true);
setErrorStore(null);
try {
const response = await client.vectorStores.retrieve(id);
setStore(response as VectorStore);
} catch (err) {
setErrorStore(
err instanceof Error
? err
: new Error("Failed to load vector store."),
);
} finally {
setIsLoadingStore(false);
}
};
fetchStore();
}, [id, client]);
useEffect(() => {
if (!id) {
setErrorFiles(new Error("Vector Store ID is missing."));
setIsLoadingFiles(false);
return;
}
const fetchFiles = async () => {
setIsLoadingFiles(true);
setErrorFiles(null);
try {
const result = await client.vectorStores.files.list(id as any);
setFiles((result as any).data);
} catch (err) {
setErrorFiles(
err instanceof Error ? err : new Error("Failed to load files."),
);
} finally {
setIsLoadingFiles(false);
}
};
fetchFiles();
}, [id]);
return (
<VectorStoreDetailView
store={store}
files={files}
isLoadingStore={isLoadingStore}
isLoadingFiles={isLoadingFiles}
errorStore={errorStore}
errorFiles={errorFiles}
id={id}
/>
);
}

View file

@ -0,0 +1,16 @@
"use client";
import React from "react";
import LogsLayout from "@/components/layout/logs-layout";
export default function VectorStoresLayout({
children,
}: {
children: React.ReactNode;
}) {
return (
<LogsLayout sectionLabel="Vector Stores" basePath="/logs/vector-stores">
{children}
</LogsLayout>
);
}

View file

@ -0,0 +1,121 @@
"use client";
import React from "react";
import { useAuthClient } from "@/hooks/use-auth-client";
import type {
ListVectorStoresResponse,
VectorStore,
} from "llama-stack-client/resources/vector-stores/vector-stores";
import { useRouter } from "next/navigation";
import { usePagination } from "@/hooks/use-pagination";
import {
Table,
TableBody,
TableCaption,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import { Skeleton } from "@/components/ui/skeleton";
export default function VectorStoresPage() {
const client = useAuthClient();
const router = useRouter();
const {
data: stores,
status,
hasMore,
error,
loadMore,
} = usePagination<VectorStore>({
limit: 20,
order: "desc",
fetchFunction: async (client, params) => {
const response = await client.vectorStores.list({
after: params.after,
limit: params.limit,
order: params.order,
} as any);
return response as ListVectorStoresResponse;
},
errorMessagePrefix: "vector stores",
});
// Auto-load all pages for infinite scroll behavior (like Responses)
React.useEffect(() => {
if (status === "idle" && hasMore) {
loadMore();
}
}, [status, hasMore, loadMore]);
if (status === "loading") {
return (
<div className="space-y-2">
<Skeleton className="h-8 w-full" />
<Skeleton className="h-4 w-full" />
<Skeleton className="h-4 w-full" />
</div>
);
}
if (status === "error") {
return <div className="text-destructive">Error: {error?.message}</div>;
}
if (!stores || stores.length === 0) {
return <p>No vector stores found.</p>;
}
return (
<div className="overflow-auto flex-1 min-h-0">
<Table>
<TableHeader>
<TableRow>
<TableHead>ID</TableHead>
<TableHead>Name</TableHead>
<TableHead>Created</TableHead>
<TableHead>Completed</TableHead>
<TableHead>Cancelled</TableHead>
<TableHead>Failed</TableHead>
<TableHead>In Progress</TableHead>
<TableHead>Total</TableHead>
<TableHead>Usage Bytes</TableHead>
<TableHead>Provider ID</TableHead>
<TableHead>Provider Vector DB ID</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{stores.map((store) => {
const fileCounts = store.file_counts;
const metadata = store.metadata || {};
const providerId = metadata.provider_id ?? "";
const providerDbId = metadata.provider_vector_db_id ?? "";
return (
<TableRow
key={store.id}
onClick={() => router.push(`/logs/vector-stores/${store.id}`)}
className="cursor-pointer hover:bg-muted/50"
>
<TableCell>{store.id}</TableCell>
<TableCell>{store.name}</TableCell>
<TableCell>
{new Date(store.created_at * 1000).toLocaleString()}
</TableCell>
<TableCell>{fileCounts.completed}</TableCell>
<TableCell>{fileCounts.cancelled}</TableCell>
<TableCell>{fileCounts.failed}</TableCell>
<TableCell>{fileCounts.in_progress}</TableCell>
<TableCell>{fileCounts.total}</TableCell>
<TableCell>{store.usage_bytes}</TableCell>
<TableCell>{providerId}</TableCell>
<TableCell>{providerDbId}</TableCell>
</TableRow>
);
})}
</TableBody>
</Table>
</div>
);
}

View file

@ -1,6 +1,11 @@
"use client"; "use client";
import { MessageSquareText, MessagesSquare, MoveUpRight } from "lucide-react"; import {
MessageSquareText,
MessagesSquare,
MoveUpRight,
Database,
} from "lucide-react";
import Link from "next/link"; import Link from "next/link";
import { usePathname } from "next/navigation"; import { usePathname } from "next/navigation";
import { cn } from "@/lib/utils"; import { cn } from "@/lib/utils";
@ -28,6 +33,11 @@ const logItems = [
url: "/logs/responses", url: "/logs/responses",
icon: MessagesSquare, icon: MessagesSquare,
}, },
{
title: "Vector Stores",
url: "/logs/vector-stores",
icon: Database,
},
{ {
title: "Documentation", title: "Documentation",
url: "https://llama-stack.readthedocs.io/en/latest/references/api_reference/index.html", url: "https://llama-stack.readthedocs.io/en/latest/references/api_reference/index.html",
@ -57,13 +67,13 @@ export function AppSidebar() {
className={cn( className={cn(
"justify-start", "justify-start",
isActive && isActive &&
"bg-gray-200 hover:bg-gray-200 text-primary hover:text-primary", "bg-gray-200 dark:bg-gray-700 hover:bg-gray-200 dark:hover:bg-gray-700 text-gray-900 dark:text-gray-100",
)} )}
> >
<Link href={item.url}> <Link href={item.url}>
<item.icon <item.icon
className={cn( className={cn(
isActive && "text-primary", isActive && "text-gray-900 dark:text-gray-100",
"mr-2 h-4 w-4", "mr-2 h-4 w-4",
)} )}
/> />

View file

@ -93,7 +93,9 @@ export function PropertyItem({
> >
<strong>{label}:</strong>{" "} <strong>{label}:</strong>{" "}
{typeof value === "string" || typeof value === "number" ? ( {typeof value === "string" || typeof value === "number" ? (
<span className="text-gray-900 font-medium">{value}</span> <span className="text-gray-900 dark:text-gray-100 font-medium">
{value}
</span>
) : ( ) : (
value value
)} )}
@ -112,7 +114,9 @@ export function PropertiesCard({ children }: PropertiesCardProps) {
<CardTitle>Properties</CardTitle> <CardTitle>Properties</CardTitle>
</CardHeader> </CardHeader>
<CardContent> <CardContent>
<ul className="space-y-2 text-sm text-gray-600">{children}</ul> <ul className="space-y-2 text-sm text-gray-600 dark:text-gray-400">
{children}
</ul>
</CardContent> </CardContent>
</Card> </Card>
); );

View file

@ -17,10 +17,10 @@ export const MessageBlock: React.FC<MessageBlockProps> = ({
}) => { }) => {
return ( return (
<div className={`mb-4 ${className}`}> <div className={`mb-4 ${className}`}>
<p className="py-1 font-semibold text-gray-800 mb-1"> <p className="py-1 font-semibold text-muted-foreground mb-1">
{label} {label}
{labelDetail && ( {labelDetail && (
<span className="text-xs text-gray-500 font-normal ml-1"> <span className="text-xs text-muted-foreground font-normal ml-1">
{labelDetail} {labelDetail}
</span> </span>
)} )}

View file

@ -0,0 +1,128 @@
"use client";
import type { VectorStore } from "llama-stack-client/resources/vector-stores/vector-stores";
import type { VectorStoreFile } from "llama-stack-client/resources/vector-stores/files";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Skeleton } from "@/components/ui/skeleton";
import {
DetailLoadingView,
DetailErrorView,
DetailNotFoundView,
DetailLayout,
PropertiesCard,
PropertyItem,
} from "@/components/layout/detail-layout";
import {
Table,
TableBody,
TableCaption,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
interface VectorStoreDetailViewProps {
store: VectorStore | null;
files: VectorStoreFile[];
isLoadingStore: boolean;
isLoadingFiles: boolean;
errorStore: Error | null;
errorFiles: Error | null;
id: string;
}
export function VectorStoreDetailView({
store,
files,
isLoadingStore,
isLoadingFiles,
errorStore,
errorFiles,
id,
}: VectorStoreDetailViewProps) {
const title = "Vector Store Details";
if (errorStore) {
return <DetailErrorView title={title} id={id} error={errorStore} />;
}
if (isLoadingStore) {
return <DetailLoadingView title={title} />;
}
if (!store) {
return <DetailNotFoundView title={title} id={id} />;
}
const mainContent = (
<>
<Card>
<CardHeader>
<CardTitle>Files</CardTitle>
</CardHeader>
<CardContent>
{isLoadingFiles ? (
<Skeleton className="h-4 w-full" />
) : errorFiles ? (
<div className="text-destructive text-sm">
Error loading files: {errorFiles.message}
</div>
) : files.length > 0 ? (
<Table>
<TableCaption>Files in this vector store</TableCaption>
<TableHeader>
<TableRow>
<TableHead>ID</TableHead>
<TableHead>Status</TableHead>
<TableHead>Created</TableHead>
<TableHead>Usage Bytes</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{files.map((file) => (
<TableRow key={file.id}>
<TableCell>{file.id}</TableCell>
<TableCell>{file.status}</TableCell>
<TableCell>
{new Date(file.created_at * 1000).toLocaleString()}
</TableCell>
<TableCell>{file.usage_bytes}</TableCell>
</TableRow>
))}
</TableBody>
</Table>
) : (
<p className="text-gray-500 italic text-sm">
No files in this vector store.
</p>
)}
</CardContent>
</Card>
</>
);
const sidebar = (
<PropertiesCard>
<PropertyItem label="ID" value={store.id} />
<PropertyItem label="Name" value={store.name || ""} />
<PropertyItem
label="Created"
value={new Date(store.created_at * 1000).toLocaleString()}
/>
<PropertyItem label="Status" value={store.status} />
<PropertyItem label="Total Files" value={store.file_counts.total} />
<PropertyItem label="Usage Bytes" value={store.usage_bytes} />
<PropertyItem
label="Provider ID"
value={(store.metadata.provider_id as string) || ""}
/>
<PropertyItem
label="Provider DB ID"
value={(store.metadata.provider_vector_db_id as string) || ""}
/>
</PropertiesCard>
);
return (
<DetailLayout title={title} mainContent={mainContent} sidebar={sidebar} />
);
}

View file

@ -15,7 +15,7 @@
"@radix-ui/react-tooltip": "^1.2.6", "@radix-ui/react-tooltip": "^1.2.6",
"class-variance-authority": "^0.7.1", "class-variance-authority": "^0.7.1",
"clsx": "^2.1.1", "clsx": "^2.1.1",
"llama-stack-client": "0.2.13", "llama-stack-client": "^0.2.14",
"lucide-react": "^0.510.0", "lucide-react": "^0.510.0",
"next": "15.3.3", "next": "15.3.3",
"next-auth": "^4.24.11", "next-auth": "^4.24.11",
@ -676,406 +676,6 @@
"tslib": "^2.4.0" "tslib": "^2.4.0"
} }
}, },
"node_modules/@esbuild/aix-ppc64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.25.5.tgz",
"integrity": "sha512-9o3TMmpmftaCMepOdA5k/yDw8SfInyzWWTjYTFCX3kPSDJMROQTb8jg+h9Cnwnmm1vOzvxN7gIfB5V2ewpjtGA==",
"cpu": [
"ppc64"
],
"license": "MIT",
"optional": true,
"os": [
"aix"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/android-arm": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/android-arm/-/android-arm-0.25.5.tgz",
"integrity": "sha512-AdJKSPeEHgi7/ZhuIPtcQKr5RQdo6OO2IL87JkianiMYMPbCtot9fxPbrMiBADOWWm3T2si9stAiVsGbTQFkbA==",
"cpu": [
"arm"
],
"license": "MIT",
"optional": true,
"os": [
"android"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/android-arm64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/android-arm64/-/android-arm64-0.25.5.tgz",
"integrity": "sha512-VGzGhj4lJO+TVGV1v8ntCZWJktV7SGCs3Pn1GRWI1SBFtRALoomm8k5E9Pmwg3HOAal2VDc2F9+PM/rEY6oIDg==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"android"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/android-x64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/android-x64/-/android-x64-0.25.5.tgz",
"integrity": "sha512-D2GyJT1kjvO//drbRT3Hib9XPwQeWd9vZoBJn+bu/lVsOZ13cqNdDeqIF/xQ5/VmWvMduP6AmXvylO/PIc2isw==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"android"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/darwin-arm64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/darwin-arm64/-/darwin-arm64-0.25.5.tgz",
"integrity": "sha512-GtaBgammVvdF7aPIgH2jxMDdivezgFu6iKpmT+48+F8Hhg5J/sfnDieg0aeG/jfSvkYQU2/pceFPDKlqZzwnfQ==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"darwin"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/darwin-x64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/darwin-x64/-/darwin-x64-0.25.5.tgz",
"integrity": "sha512-1iT4FVL0dJ76/q1wd7XDsXrSW+oLoquptvh4CLR4kITDtqi2e/xwXwdCVH8hVHU43wgJdsq7Gxuzcs6Iq/7bxQ==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"darwin"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/freebsd-arm64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/freebsd-arm64/-/freebsd-arm64-0.25.5.tgz",
"integrity": "sha512-nk4tGP3JThz4La38Uy/gzyXtpkPW8zSAmoUhK9xKKXdBCzKODMc2adkB2+8om9BDYugz+uGV7sLmpTYzvmz6Sw==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"freebsd"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/freebsd-x64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/freebsd-x64/-/freebsd-x64-0.25.5.tgz",
"integrity": "sha512-PrikaNjiXdR2laW6OIjlbeuCPrPaAl0IwPIaRv+SMV8CiM8i2LqVUHFC1+8eORgWyY7yhQY+2U2fA55mBzReaw==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"freebsd"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-arm": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/linux-arm/-/linux-arm-0.25.5.tgz",
"integrity": "sha512-cPzojwW2okgh7ZlRpcBEtsX7WBuqbLrNXqLU89GxWbNt6uIg78ET82qifUy3W6OVww6ZWobWub5oqZOVtwolfw==",
"cpu": [
"arm"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-arm64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/linux-arm64/-/linux-arm64-0.25.5.tgz",
"integrity": "sha512-Z9kfb1v6ZlGbWj8EJk9T6czVEjjq2ntSYLY2cw6pAZl4oKtfgQuS4HOq41M/BcoLPzrUbNd+R4BXFyH//nHxVg==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-ia32": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/linux-ia32/-/linux-ia32-0.25.5.tgz",
"integrity": "sha512-sQ7l00M8bSv36GLV95BVAdhJ2QsIbCuCjh/uYrWiMQSUuV+LpXwIqhgJDcvMTj+VsQmqAHL2yYaasENvJ7CDKA==",
"cpu": [
"ia32"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-loong64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/linux-loong64/-/linux-loong64-0.25.5.tgz",
"integrity": "sha512-0ur7ae16hDUC4OL5iEnDb0tZHDxYmuQyhKhsPBV8f99f6Z9KQM02g33f93rNH5A30agMS46u2HP6qTdEt6Q1kg==",
"cpu": [
"loong64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-mips64el": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/linux-mips64el/-/linux-mips64el-0.25.5.tgz",
"integrity": "sha512-kB/66P1OsHO5zLz0i6X0RxlQ+3cu0mkxS3TKFvkb5lin6uwZ/ttOkP3Z8lfR9mJOBk14ZwZ9182SIIWFGNmqmg==",
"cpu": [
"mips64el"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-ppc64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/linux-ppc64/-/linux-ppc64-0.25.5.tgz",
"integrity": "sha512-UZCmJ7r9X2fe2D6jBmkLBMQetXPXIsZjQJCjgwpVDz+YMcS6oFR27alkgGv3Oqkv07bxdvw7fyB71/olceJhkQ==",
"cpu": [
"ppc64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-riscv64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/linux-riscv64/-/linux-riscv64-0.25.5.tgz",
"integrity": "sha512-kTxwu4mLyeOlsVIFPfQo+fQJAV9mh24xL+y+Bm6ej067sYANjyEw1dNHmvoqxJUCMnkBdKpvOn0Ahql6+4VyeA==",
"cpu": [
"riscv64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-s390x": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/linux-s390x/-/linux-s390x-0.25.5.tgz",
"integrity": "sha512-K2dSKTKfmdh78uJ3NcWFiqyRrimfdinS5ErLSn3vluHNeHVnBAFWC8a4X5N+7FgVE1EjXS1QDZbpqZBjfrqMTQ==",
"cpu": [
"s390x"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/linux-x64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/linux-x64/-/linux-x64-0.25.5.tgz",
"integrity": "sha512-uhj8N2obKTE6pSZ+aMUbqq+1nXxNjZIIjCjGLfsWvVpy7gKCOL6rsY1MhRh9zLtUtAI7vpgLMK6DxjO8Qm9lJw==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"linux"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/netbsd-arm64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/netbsd-arm64/-/netbsd-arm64-0.25.5.tgz",
"integrity": "sha512-pwHtMP9viAy1oHPvgxtOv+OkduK5ugofNTVDilIzBLpoWAM16r7b/mxBvfpuQDpRQFMfuVr5aLcn4yveGvBZvw==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"netbsd"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/netbsd-x64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/netbsd-x64/-/netbsd-x64-0.25.5.tgz",
"integrity": "sha512-WOb5fKrvVTRMfWFNCroYWWklbnXH0Q5rZppjq0vQIdlsQKuw6mdSihwSo4RV/YdQ5UCKKvBy7/0ZZYLBZKIbwQ==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"netbsd"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/openbsd-arm64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/openbsd-arm64/-/openbsd-arm64-0.25.5.tgz",
"integrity": "sha512-7A208+uQKgTxHd0G0uqZO8UjK2R0DDb4fDmERtARjSHWxqMTye4Erz4zZafx7Di9Cv+lNHYuncAkiGFySoD+Mw==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"openbsd"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/openbsd-x64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/openbsd-x64/-/openbsd-x64-0.25.5.tgz",
"integrity": "sha512-G4hE405ErTWraiZ8UiSoesH8DaCsMm0Cay4fsFWOOUcz8b8rC6uCvnagr+gnioEjWn0wC+o1/TAHt+It+MpIMg==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"openbsd"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/sunos-x64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/sunos-x64/-/sunos-x64-0.25.5.tgz",
"integrity": "sha512-l+azKShMy7FxzY0Rj4RCt5VD/q8mG/e+mDivgspo+yL8zW7qEwctQ6YqKX34DTEleFAvCIUviCFX1SDZRSyMQA==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"sunos"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/win32-arm64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/win32-arm64/-/win32-arm64-0.25.5.tgz",
"integrity": "sha512-O2S7SNZzdcFG7eFKgvwUEZ2VG9D/sn/eIiz8XRZ1Q/DO5a3s76Xv0mdBzVM5j5R639lXQmPmSo0iRpHqUUrsxw==",
"cpu": [
"arm64"
],
"license": "MIT",
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/win32-ia32": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/win32-ia32/-/win32-ia32-0.25.5.tgz",
"integrity": "sha512-onOJ02pqs9h1iMJ1PQphR+VZv8qBMQ77Klcsqv9CNW2w6yLqoURLcgERAIurY6QE63bbLuqgP9ATqajFLK5AMQ==",
"cpu": [
"ia32"
],
"license": "MIT",
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@esbuild/win32-x64": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/@esbuild/win32-x64/-/win32-x64-0.25.5.tgz",
"integrity": "sha512-TXv6YnJ8ZMVdX+SXWVBo/0p8LTcrUYngpWjvm91TMjjBQii7Oz11Lw5lbDV5Y0TzuhSJHwiH4hEtC1I42mMS0g==",
"cpu": [
"x64"
],
"license": "MIT",
"optional": true,
"os": [
"win32"
],
"engines": {
"node": ">=18"
}
},
"node_modules/@eslint-community/eslint-utils": { "node_modules/@eslint-community/eslint-utils": {
"version": "4.7.0", "version": "4.7.0",
"resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.7.0.tgz", "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.7.0.tgz",
@ -5999,46 +5599,6 @@
"url": "https://github.com/sponsors/ljharb" "url": "https://github.com/sponsors/ljharb"
} }
}, },
"node_modules/esbuild": {
"version": "0.25.5",
"resolved": "https://registry.npmjs.org/esbuild/-/esbuild-0.25.5.tgz",
"integrity": "sha512-P8OtKZRv/5J5hhz0cUAdu/cLuPIKXpQl1R9pZtvmHWQvrAUVd0UNIPT4IB4W3rNOqVO0rlqHmCIbSwxh/c9yUQ==",
"hasInstallScript": true,
"license": "MIT",
"bin": {
"esbuild": "bin/esbuild"
},
"engines": {
"node": ">=18"
},
"optionalDependencies": {
"@esbuild/aix-ppc64": "0.25.5",
"@esbuild/android-arm": "0.25.5",
"@esbuild/android-arm64": "0.25.5",
"@esbuild/android-x64": "0.25.5",
"@esbuild/darwin-arm64": "0.25.5",
"@esbuild/darwin-x64": "0.25.5",
"@esbuild/freebsd-arm64": "0.25.5",
"@esbuild/freebsd-x64": "0.25.5",
"@esbuild/linux-arm": "0.25.5",
"@esbuild/linux-arm64": "0.25.5",
"@esbuild/linux-ia32": "0.25.5",
"@esbuild/linux-loong64": "0.25.5",
"@esbuild/linux-mips64el": "0.25.5",
"@esbuild/linux-ppc64": "0.25.5",
"@esbuild/linux-riscv64": "0.25.5",
"@esbuild/linux-s390x": "0.25.5",
"@esbuild/linux-x64": "0.25.5",
"@esbuild/netbsd-arm64": "0.25.5",
"@esbuild/netbsd-x64": "0.25.5",
"@esbuild/openbsd-arm64": "0.25.5",
"@esbuild/openbsd-x64": "0.25.5",
"@esbuild/sunos-x64": "0.25.5",
"@esbuild/win32-arm64": "0.25.5",
"@esbuild/win32-ia32": "0.25.5",
"@esbuild/win32-x64": "0.25.5"
}
},
"node_modules/escalade": { "node_modules/escalade": {
"version": "3.2.0", "version": "3.2.0",
"resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz", "resolved": "https://registry.npmjs.org/escalade/-/escalade-3.2.0.tgz",
@ -6993,6 +6553,7 @@
"version": "2.3.3", "version": "2.3.3",
"resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz", "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.3.tgz",
"integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==", "integrity": "sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==",
"dev": true,
"hasInstallScript": true, "hasInstallScript": true,
"license": "MIT", "license": "MIT",
"optional": true, "optional": true,
@ -7154,6 +6715,7 @@
"version": "4.10.0", "version": "4.10.0",
"resolved": "https://registry.npmjs.org/get-tsconfig/-/get-tsconfig-4.10.0.tgz", "resolved": "https://registry.npmjs.org/get-tsconfig/-/get-tsconfig-4.10.0.tgz",
"integrity": "sha512-kGzZ3LWWQcGIAmg6iWvXn0ei6WDtV26wzHRMwDSzmAbcXrTEXxHy6IehI6/4eT6VRKyMP1eF1VqwrVUmE/LR7A==", "integrity": "sha512-kGzZ3LWWQcGIAmg6iWvXn0ei6WDtV26wzHRMwDSzmAbcXrTEXxHy6IehI6/4eT6VRKyMP1eF1VqwrVUmE/LR7A==",
"dev": true,
"license": "MIT", "license": "MIT",
"dependencies": { "dependencies": {
"resolve-pkg-maps": "^1.0.0" "resolve-pkg-maps": "^1.0.0"
@ -9537,9 +9099,10 @@
"license": "MIT" "license": "MIT"
}, },
"node_modules/llama-stack-client": { "node_modules/llama-stack-client": {
"version": "0.2.13", "version": "0.2.14",
"resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.13.tgz", "resolved": "https://registry.npmjs.org/llama-stack-client/-/llama-stack-client-0.2.14.tgz",
"integrity": "sha512-R1rTFLwgUimr+KjEUkzUvFL6vLASwS9qj3UDSVkJ5BmrKAs5GwVAMeL7yZaTBXGuPUVh124WSlC4d9H0FjWqLA==", "integrity": "sha512-bVU3JHp+EPEKR0Vb9vcd9ZyQj/72jSDuptKLwOXET9WrkphIQ8xuW5ueecMTgq8UEls3lwB3HiZM2cDOR9eDsQ==",
"license": "Apache-2.0",
"dependencies": { "dependencies": {
"@types/node": "^18.11.18", "@types/node": "^18.11.18",
"@types/node-fetch": "^2.6.4", "@types/node-fetch": "^2.6.4",
@ -9547,8 +9110,7 @@
"agentkeepalive": "^4.2.1", "agentkeepalive": "^4.2.1",
"form-data-encoder": "1.7.2", "form-data-encoder": "1.7.2",
"formdata-node": "^4.3.2", "formdata-node": "^4.3.2",
"node-fetch": "^2.6.7", "node-fetch": "^2.6.7"
"tsx": "^4.19.2"
} }
}, },
"node_modules/llama-stack-client/node_modules/@types/node": { "node_modules/llama-stack-client/node_modules/@types/node": {
@ -11148,6 +10710,7 @@
"version": "1.0.0", "version": "1.0.0",
"resolved": "https://registry.npmjs.org/resolve-pkg-maps/-/resolve-pkg-maps-1.0.0.tgz", "resolved": "https://registry.npmjs.org/resolve-pkg-maps/-/resolve-pkg-maps-1.0.0.tgz",
"integrity": "sha512-seS2Tj26TBVOC2NIc2rOe2y2ZO7efxITtLZcGSOnHHNOQ7CkiUBfw0Iw2ck6xkIhPwLhKNLS8BO+hEpngQlqzw==", "integrity": "sha512-seS2Tj26TBVOC2NIc2rOe2y2ZO7efxITtLZcGSOnHHNOQ7CkiUBfw0Iw2ck6xkIhPwLhKNLS8BO+hEpngQlqzw==",
"dev": true,
"license": "MIT", "license": "MIT",
"funding": { "funding": {
"url": "https://github.com/privatenumber/resolve-pkg-maps?sponsor=1" "url": "https://github.com/privatenumber/resolve-pkg-maps?sponsor=1"
@ -12198,25 +11761,6 @@
"integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==", "integrity": "sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==",
"license": "0BSD" "license": "0BSD"
}, },
"node_modules/tsx": {
"version": "4.19.4",
"resolved": "https://registry.npmjs.org/tsx/-/tsx-4.19.4.tgz",
"integrity": "sha512-gK5GVzDkJK1SI1zwHf32Mqxf2tSJkNx+eYcNly5+nHvWqXUJYUkWBQtKauoESz3ymezAI++ZwT855x5p5eop+Q==",
"license": "MIT",
"dependencies": {
"esbuild": "~0.25.0",
"get-tsconfig": "^4.7.5"
},
"bin": {
"tsx": "dist/cli.mjs"
},
"engines": {
"node": ">=18.0.0"
},
"optionalDependencies": {
"fsevents": "~2.3.3"
}
},
"node_modules/tw-animate-css": { "node_modules/tw-animate-css": {
"version": "1.2.9", "version": "1.2.9",
"resolved": "https://registry.npmjs.org/tw-animate-css/-/tw-animate-css-1.2.9.tgz", "resolved": "https://registry.npmjs.org/tw-animate-css/-/tw-animate-css-1.2.9.tgz",

View file

@ -20,7 +20,7 @@
"@radix-ui/react-tooltip": "^1.2.6", "@radix-ui/react-tooltip": "^1.2.6",
"class-variance-authority": "^0.7.1", "class-variance-authority": "^0.7.1",
"clsx": "^2.1.1", "clsx": "^2.1.1",
"llama-stack-client": "0.2.13", "llama-stack-client": "^0.2.14",
"lucide-react": "^0.510.0", "lucide-react": "^0.510.0",
"next": "15.3.3", "next": "15.3.3",
"next-auth": "^4.24.11", "next-auth": "^4.24.11",

View file

@ -32,7 +32,7 @@ dependencies = [
"openai>=1.66", "openai>=1.66",
"prompt-toolkit", "prompt-toolkit",
"python-dotenv", "python-dotenv",
"python-jose", "python-jose[cryptography]",
"pydantic>=2", "pydantic>=2",
"rich", "rich",
"starlette", "starlette",
@ -42,8 +42,8 @@ dependencies = [
"h11>=0.16.0", "h11>=0.16.0",
"python-multipart>=0.0.20", # For fastapi Form "python-multipart>=0.0.20", # For fastapi Form
"uvicorn>=0.34.0", # server "uvicorn>=0.34.0", # server
"opentelemetry-sdk", # server "opentelemetry-sdk>=1.30.0", # server
"opentelemetry-exporter-otlp-proto-http", # server "opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
"aiosqlite>=0.21.0", # server - for metadata store "aiosqlite>=0.21.0", # server - for metadata store
"asyncpg", # for metadata store "asyncpg", # for metadata store
] ]
@ -58,12 +58,13 @@ ui = [
[dependency-groups] [dependency-groups]
dev = [ dev = [
"pytest", "pytest>=8.4",
"pytest-timeout", "pytest-timeout",
"pytest-asyncio", "pytest-asyncio>=1.0",
"pytest-cov", "pytest-cov",
"pytest-html", "pytest-html",
"pytest-json-report", "pytest-json-report",
"pytest-socket", # For blocking network access in unit tests
"nbval", # For notebook testing "nbval", # For notebook testing
"black", "black",
"ruff", "ruff",
@ -87,6 +88,8 @@ unit = [
"blobfile", "blobfile",
"faiss-cpu", "faiss-cpu",
"pymilvus>=2.5.12", "pymilvus>=2.5.12",
"litellm",
"together",
] ]
# These are the core dependencies required for running integration tests. They are shared across all # These are the core dependencies required for running integration tests. They are shared across all
# providers. If a provider requires additional dependencies, please add them to your environment # providers. If a provider requires additional dependencies, please add them to your environment
@ -226,7 +229,6 @@ follow_imports = "silent"
exclude = [ exclude = [
# As we fix more and more of these, we should remove them from the list # As we fix more and more of these, we should remove them from the list
"^llama_stack/cli/download\\.py$", "^llama_stack/cli/download\\.py$",
"^llama_stack/cli/stack/_build\\.py$",
"^llama_stack/distribution/build\\.py$", "^llama_stack/distribution/build\\.py$",
"^llama_stack/distribution/client\\.py$", "^llama_stack/distribution/client\\.py$",
"^llama_stack/distribution/request_headers\\.py$", "^llama_stack/distribution/request_headers\\.py$",
@ -256,7 +258,6 @@ exclude = [
"^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$", "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
"^llama_stack/providers/inline/inference/vllm/", "^llama_stack/providers/inline/inference/vllm/",
"^llama_stack/providers/inline/post_training/common/validator\\.py$", "^llama_stack/providers/inline/post_training/common/validator\\.py$",
"^llama_stack/providers/inline/post_training/torchtune/post_training\\.py$",
"^llama_stack/providers/inline/safety/code_scanner/", "^llama_stack/providers/inline/safety/code_scanner/",
"^llama_stack/providers/inline/safety/llama_guard/", "^llama_stack/providers/inline/safety/llama_guard/",
"^llama_stack/providers/inline/safety/prompt_guard/", "^llama_stack/providers/inline/safety/prompt_guard/",
@ -341,3 +342,9 @@ warn_required_dynamic_aliases = true
[tool.ruff.lint.pep8-naming] [tool.ruff.lint.pep8-naming]
classmethod-decorators = ["classmethod", "pydantic.field_validator"] classmethod-decorators = ["classmethod", "pydantic.field_validator"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
markers = [
"allow_network: Allow network access for specific unit tests",
]
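
With pytest-socket added to the dev dependencies and this marker registered, a unit test that genuinely needs network access would opt in roughly like this (a minimal sketch; the exact hookup between the marker and pytest-socket is assumed to live in the test suite's conftest):

    import socket

    import pytest


    @pytest.mark.allow_network  # opt this test out of the socket block
    def test_can_resolve_localhost():
        # Without the marker, pytest-socket would raise SocketBlockedError here
        # (assuming the conftest enables sockets only for marked tests).
        assert socket.gethostbyname("localhost")
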

View file

@ -28,6 +28,8 @@ certifi==2025.1.31
# httpcore # httpcore
# httpx # httpx
# requests # requests
cffi==1.17.1 ; platform_python_implementation != 'PyPy'
# via cryptography
charset-normalizer==3.4.1 charset-normalizer==3.4.1
# via requests # via requests
click==8.1.8 click==8.1.8
@ -38,6 +40,8 @@ colorama==0.4.6 ; sys_platform == 'win32'
# via # via
# click # click
# tqdm # tqdm
cryptography==45.0.5
# via python-jose
deprecated==1.2.18 deprecated==1.2.18
# via # via
# opentelemetry-api # opentelemetry-api
@ -156,6 +160,8 @@ pyasn1==0.4.8
# via # via
# python-jose # python-jose
# rsa # rsa
pycparser==2.22 ; platform_python_implementation != 'PyPy'
# via cffi
pydantic==2.10.6 pydantic==2.10.6
# via # via
# fastapi # fastapi

View file

@ -16,4 +16,4 @@ if [ $FOUND_PYTHON -ne 0 ]; then
uv python install "$PYTHON_VERSION" uv python install "$PYTHON_VERSION"
fi fi
uv run --python "$PYTHON_VERSION" --with-editable . --group unit pytest --asyncio-mode=auto -s -v tests/unit/ $@ uv run --python "$PYTHON_VERSION" --with-editable . --group unit pytest -s -v tests/unit/ $@

View file

@ -7,7 +7,8 @@ FROM --platform=linux/amd64 ollama/ollama:latest
RUN ollama serve & \ RUN ollama serve & \
sleep 5 && \ sleep 5 && \
ollama pull llama3.2:3b-instruct-fp16 && \ ollama pull llama3.2:3b-instruct-fp16 && \
ollama pull all-minilm:l6-v2 ollama pull all-minilm:l6-v2 && \
ollama pull llama-guard3:1b
# Set the entrypoint to start ollama serve # Set the entrypoint to start ollama serve
ENTRYPOINT ["ollama", "serve"] ENTRYPOINT ["ollama", "serve"]

View file

@ -44,7 +44,6 @@ def common_params(inference_model):
) )
@pytest.mark.asyncio
@pytest.mark.skip(reason="This test needs to be migrated to api / client-sdk world") @pytest.mark.skip(reason="This test needs to be migrated to api / client-sdk world")
async def test_delete_agents_and_sessions(self, agents_stack, common_params): async def test_delete_agents_and_sessions(self, agents_stack, common_params):
agents_impl = agents_stack.impls[Api.agents] agents_impl = agents_stack.impls[Api.agents]
@ -73,7 +72,6 @@ async def test_delete_agents_and_sessions(self, agents_stack, common_params):
assert agent_response is None assert agent_response is None
@pytest.mark.asyncio
@pytest.mark.skip(reason="This test needs to be migrated to api / client-sdk world") @pytest.mark.skip(reason="This test needs to be migrated to api / client-sdk world")
async def test_get_agent_turns_and_steps(self, agents_stack, sample_messages, common_params): async def test_get_agent_turns_and_steps(self, agents_stack, sample_messages, common_params):
agents_impl = agents_stack.impls[Api.agents] agents_impl = agents_stack.impls[Api.agents]

View file

@ -6,6 +6,7 @@
import inspect import inspect
import os import os
import signal
import socket import socket
import subprocess import subprocess
import tempfile import tempfile
@ -45,6 +46,8 @@ def start_llama_stack_server(config_name: str) -> subprocess.Popen:
stderr=subprocess.PIPE, # keep stderr to see errors stderr=subprocess.PIPE, # keep stderr to see errors
text=True, text=True,
env={**os.environ, "LLAMA_STACK_LOG_FILE": "server.log"}, env={**os.environ, "LLAMA_STACK_LOG_FILE": "server.log"},
# Create new process group so we can kill all child processes
preexec_fn=os.setsid,
) )
return process return process
@ -197,7 +200,7 @@ def llama_stack_client(request, provider_data):
server_process = start_llama_stack_server(config_name) server_process = start_llama_stack_server(config_name)
# Wait for server to be ready # Wait for server to be ready
if not wait_for_server_ready(base_url, timeout=30, process=server_process): if not wait_for_server_ready(base_url, timeout=120, process=server_process):
print("Server failed to start within timeout") print("Server failed to start within timeout")
server_process.terminate() server_process.terminate()
raise RuntimeError( raise RuntimeError(
@ -215,6 +218,7 @@ def llama_stack_client(request, provider_data):
return LlamaStackClient( return LlamaStackClient(
base_url=base_url, base_url=base_url,
provider_data=provider_data, provider_data=provider_data,
timeout=int(os.environ.get("LLAMA_STACK_CLIENT_TIMEOUT", "30")),
) )
# check if this looks like a URL using proper URL parsing # check if this looks like a URL using proper URL parsing
@ -267,14 +271,17 @@ def cleanup_server_process(request):
print(f"Server process already terminated with return code: {server_process.returncode}") print(f"Server process already terminated with return code: {server_process.returncode}")
return return
try: try:
server_process.terminate() print(f"Terminating process {server_process.pid} and its group...")
# Kill the entire process group
os.killpg(os.getpgid(server_process.pid), signal.SIGTERM)
server_process.wait(timeout=10) server_process.wait(timeout=10)
print("Server process terminated gracefully") print("Server process and children terminated gracefully")
except subprocess.TimeoutExpired: except subprocess.TimeoutExpired:
print("Server process did not terminate gracefully, killing it") print("Server process did not terminate gracefully, killing it")
server_process.kill() # Force kill the entire process group
os.killpg(os.getpgid(server_process.pid), signal.SIGKILL)
server_process.wait() server_process.wait()
print("Server process killed") print("Server process and children killed")
except Exception as e: except Exception as e:
print(f"Error during server cleanup: {e}") print(f"Error during server cleanup: {e}")
else: else:
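The setsid/killpg pairing above is easiest to see in isolation: the child is started in its own process group, and signals are then sent to the group id rather than a single pid, so any grandchildren (worker processes and the like) cannot outlive the fixture. A standalone sketch (POSIX only, with an illustrative command):

import os
import signal
import subprocess

# Illustrative child that spawns its own grandchild.
proc = subprocess.Popen(
    ["bash", "-c", "sleep 300 & wait"],
    preexec_fn=os.setsid,  # new session, hence a new process group
)

try:
    proc.wait(timeout=1)
except subprocess.TimeoutExpired:
    # Signal the whole group: the grandchild dies along with the child.
    os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
    try:
        proc.wait(timeout=10)
    except subprocess.TimeoutExpired:
        os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
        proc.wait()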

View file

@ -71,7 +71,6 @@ def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, mode
"remote::cerebras", "remote::cerebras",
"remote::databricks", "remote::databricks",
"remote::runpod", "remote::runpod",
"remote::sambanova",
"remote::tgi", "remote::tgi",
): ):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.") pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI chat completions.")

View file

@ -4,20 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import pytest
from llama_stack_client import LlamaStackClient from llama_stack_client import LlamaStackClient
from llama_stack import LlamaStackAsLibraryClient from llama_stack import LlamaStackAsLibraryClient
class TestInspect: class TestInspect:
@pytest.mark.asyncio
def test_health(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient): def test_health(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
health = llama_stack_client.inspect.health() health = llama_stack_client.inspect.health()
assert health is not None assert health is not None
assert health.status == "OK" assert health.status == "OK"
@pytest.mark.asyncio
def test_version(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient): def test_version(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
version = llama_stack_client.inspect.version() version = llama_stack_client.inspect.version()
assert version is not None assert version is not None

View file

@ -4,14 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import pytest
from llama_stack_client import LlamaStackClient from llama_stack_client import LlamaStackClient
from llama_stack import LlamaStackAsLibraryClient from llama_stack import LlamaStackAsLibraryClient
class TestProviders: class TestProviders:
@pytest.mark.asyncio
def test_providers(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient): def test_providers(self, llama_stack_client: LlamaStackAsLibraryClient | LlamaStackClient):
provider_list = llama_stack_client.providers.list() provider_list = llama_stack_client.providers.list()
assert provider_list is not None assert provider_list is not None

View file

@ -14,8 +14,7 @@ from llama_stack.distribution.access_control.access_control import default_polic
from llama_stack.distribution.datatypes import User from llama_stack.distribution.datatypes import User
from llama_stack.providers.utils.sqlstore.api import ColumnType from llama_stack.providers.utils.sqlstore.api import ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlalchemy_sqlstore import SqlAlchemySqlStoreImpl from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig, sqlstore_impl
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig
def get_postgres_config(): def get_postgres_config():
@ -30,144 +29,211 @@ def get_postgres_config():
def get_sqlite_config(): def get_sqlite_config():
"""Get SQLite configuration with temporary database.""" """Get SQLite configuration with temporary file database."""
tmp_file = tempfile.NamedTemporaryFile(suffix=".db", delete=False) temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
tmp_file.close() temp_file.close()
return SqliteSqlStoreConfig(db_path=tmp_file.name), tmp_file.name return SqliteSqlStoreConfig(db_path=temp_file.name)
@pytest.mark.asyncio # Backend configurations for parametrized tests
@pytest.mark.parametrize( BACKEND_CONFIGS = [
"backend_config", pytest.param(
[ get_postgres_config,
pytest.param( marks=pytest.mark.skipif(
("postgres", get_postgres_config), not os.environ.get("ENABLE_POSTGRES_TESTS"),
marks=pytest.mark.skipif( reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable",
not os.environ.get("ENABLE_POSTGRES_TESTS"),
reason="PostgreSQL tests require ENABLE_POSTGRES_TESTS environment variable",
),
id="postgres",
), ),
pytest.param(("sqlite", get_sqlite_config), id="sqlite"), id="postgres",
], ),
) pytest.param(get_sqlite_config, id="sqlite"),
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") ]
async def test_json_comparison(mock_get_authenticated_user, backend_config):
"""Test that JSON column comparisons work correctly for both PostgreSQL and SQLite"""
backend_name, config_func = backend_config
# Handle different config types
if backend_name == "postgres": @pytest.fixture
config = config_func() def authorized_store(backend_config):
cleanup_path = None """Set up authorized store with proper cleanup."""
else: # sqlite config_func = backend_config
config, cleanup_path = config_func()
config = config_func()
base_sqlstore = sqlstore_impl(config)
authorized_store = AuthorizedSqlStore(base_sqlstore)
yield authorized_store
if hasattr(config, "db_path"):
try:
os.unlink(config.db_path)
except (OSError, FileNotFoundError):
pass
async def create_test_table(authorized_store, table_name):
"""Create a test table with standard schema."""
await authorized_store.create_table(
table=table_name,
schema={
"id": ColumnType.STRING,
"data": ColumnType.STRING,
},
)
async def cleanup_records(sql_store, table_name, record_ids):
"""Clean up test records."""
for record_id in record_ids:
try:
await sql_store.delete(table_name, {"id": record_id})
except Exception:
pass
@pytest.mark.parametrize("backend_config", BACKEND_CONFIGS)
@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
async def test_authorized_store_attributes(mock_get_authenticated_user, authorized_store, request):
"""Test that JSON column comparisons work correctly for both PostgreSQL and SQLite"""
backend_name = request.node.callspec.id
# Create test table
table_name = f"test_json_comparison_{backend_name}"
await create_test_table(authorized_store, table_name)
try: try:
base_sqlstore = SqlAlchemySqlStoreImpl(config) # Test with no authenticated user (should handle JSON null comparison)
authorized_store = AuthorizedSqlStore(base_sqlstore) mock_get_authenticated_user.return_value = None
# Create test table # Insert some test data
table_name = f"test_json_comparison_{backend_name}" await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
await authorized_store.create_table(
table=table_name, # Test fetching with no user - should not error on JSON comparison
schema={ result = await authorized_store.fetch_all(table_name, policy=default_policy())
"id": ColumnType.STRING, assert len(result.data) == 1
"data": ColumnType.STRING, assert result.data[0]["id"] == "1"
}, assert result.data[0]["access_attributes"] is None
# Test with authenticated user
test_user = User("test-user", {"roles": ["admin"]})
mock_get_authenticated_user.return_value = test_user
# Insert data with user attributes
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
# Fetch all - admin should see both
result = await authorized_store.fetch_all(table_name, policy=default_policy())
assert len(result.data) == 2
# Test with non-admin user
regular_user = User("regular-user", {"roles": ["user"]})
mock_get_authenticated_user.return_value = regular_user
# Should only see public record
result = await authorized_store.fetch_all(table_name, policy=default_policy())
assert len(result.data) == 1
assert result.data[0]["id"] == "1"
# Test the category missing branch: user with multiple attributes
multi_user = User("multi-user", {"roles": ["admin"], "teams": ["dev"]})
mock_get_authenticated_user.return_value = multi_user
# Insert record with multi-user (has both roles and teams)
await authorized_store.insert(table_name, {"id": "3", "data": "multi_user_data"})
# Test different user types to create records with different attribute patterns
# Record with only roles (teams category will be missing)
roles_only_user = User("roles-user", {"roles": ["admin"]})
mock_get_authenticated_user.return_value = roles_only_user
await authorized_store.insert(table_name, {"id": "4", "data": "roles_only_data"})
# Record with only teams (roles category will be missing)
teams_only_user = User("teams-user", {"teams": ["dev"]})
mock_get_authenticated_user.return_value = teams_only_user
await authorized_store.insert(table_name, {"id": "5", "data": "teams_only_data"})
# Record with different roles/teams (shouldn't match our test user)
different_user = User("different-user", {"roles": ["user"], "teams": ["qa"]})
mock_get_authenticated_user.return_value = different_user
await authorized_store.insert(table_name, {"id": "6", "data": "different_user_data"})
# Now test with the multi-user who has both roles=admin and teams=dev
mock_get_authenticated_user.return_value = multi_user
result = await authorized_store.fetch_all(table_name, policy=default_policy())
# Should see:
# - public record (1) - no access_attributes
# - admin record (2) - user matches roles=admin, teams missing (allowed)
# - multi_user record (3) - user matches both roles=admin and teams=dev
# - roles_only record (4) - user matches roles=admin, teams missing (allowed)
# - teams_only record (5) - user matches teams=dev, roles missing (allowed)
# Should NOT see:
# - different_user record (6) - user doesn't match roles=user or teams=qa
expected_ids = {"1", "2", "3", "4", "5"}
actual_ids = {record["id"] for record in result.data}
assert actual_ids == expected_ids, f"Expected to see records {expected_ids} but got {actual_ids}"
# Verify the category missing logic specifically
# Records 4 and 5 test the "category missing" branch where one attribute category is missing
category_test_ids = {record["id"] for record in result.data if record["id"] in ["4", "5"]}
assert category_test_ids == {"4", "5"}, (
f"Category missing logic failed: expected 4,5 but got {category_test_ids}"
) )
try: finally:
# Test with no authenticated user (should handle JSON null comparison) # Clean up records
mock_get_authenticated_user.return_value = None await cleanup_records(authorized_store.sql_store, table_name, ["1", "2", "3", "4", "5", "6"])
# Insert some test data
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
# Test fetching with no user - should not error on JSON comparison @pytest.mark.parametrize("backend_config", BACKEND_CONFIGS)
result = await authorized_store.fetch_all(table_name, policy=default_policy()) @patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user")
assert len(result.data) == 1 async def test_user_ownership_policy(mock_get_authenticated_user, authorized_store, request):
assert result.data[0]["id"] == "1" """Test that 'user is owner' policies work correctly with record ownership"""
assert result.data[0]["access_attributes"] is None from llama_stack.distribution.access_control.datatypes import AccessRule, Action, Scope
# Test with authenticated user backend_name = request.node.callspec.id
test_user = User("test-user", {"roles": ["admin"]})
mock_get_authenticated_user.return_value = test_user
# Insert data with user attributes # Create test table
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"}) table_name = f"test_ownership_{backend_name}"
await create_test_table(authorized_store, table_name)
# Fetch all - admin should see both try:
result = await authorized_store.fetch_all(table_name, policy=default_policy()) # Test with first user who creates records
assert len(result.data) == 2 user1 = User("user1", {"roles": ["admin"]})
mock_get_authenticated_user.return_value = user1
# Test with non-admin user # Insert a record owned by user1
regular_user = User("regular-user", {"roles": ["user"]}) await authorized_store.insert(table_name, {"id": "1", "data": "user1_data"})
mock_get_authenticated_user.return_value = regular_user
# Should only see public record # Test with second user
result = await authorized_store.fetch_all(table_name, policy=default_policy()) user2 = User("user2", {"roles": ["user"]})
assert len(result.data) == 1 mock_get_authenticated_user.return_value = user2
assert result.data[0]["id"] == "1"
# Test the category missing branch: user with multiple attributes # Insert a record owned by user2
multi_user = User("multi-user", {"roles": ["admin"], "teams": ["dev"]}) await authorized_store.insert(table_name, {"id": "2", "data": "user2_data"})
mock_get_authenticated_user.return_value = multi_user
# Insert record with multi-user (has both roles and teams) # Create a policy that only allows access when user is the owner
await authorized_store.insert(table_name, {"id": "3", "data": "multi_user_data"}) owner_only_policy = [
AccessRule(
permit=Scope(actions=[Action.READ]),
when=["user is owner"],
),
]
# Test different user types to create records with different attribute patterns # Test user1 access - should only see their own record
# Record with only roles (teams category will be missing) mock_get_authenticated_user.return_value = user1
roles_only_user = User("roles-user", {"roles": ["admin"]}) result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
mock_get_authenticated_user.return_value = roles_only_user assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
await authorized_store.insert(table_name, {"id": "4", "data": "roles_only_data"}) assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"
# Record with only teams (roles category will be missing) # Test user2 access - should only see their own record
teams_only_user = User("teams-user", {"teams": ["dev"]}) mock_get_authenticated_user.return_value = user2
mock_get_authenticated_user.return_value = teams_only_user result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
await authorized_store.insert(table_name, {"id": "5", "data": "teams_only_data"}) assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"
# Record with different roles/teams (shouldn't match our test user) # Test with anonymous user - should see no records
different_user = User("different-user", {"roles": ["user"], "teams": ["qa"]}) mock_get_authenticated_user.return_value = None
mock_get_authenticated_user.return_value = different_user result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
await authorized_store.insert(table_name, {"id": "6", "data": "different_user_data"}) assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"
# Now test with the multi-user who has both roles=admin and teams=dev
mock_get_authenticated_user.return_value = multi_user
result = await authorized_store.fetch_all(table_name, policy=default_policy())
# Should see:
# - public record (1) - no access_attributes
# - admin record (2) - user matches roles=admin, teams missing (allowed)
# - multi_user record (3) - user matches both roles=admin and teams=dev
# - roles_only record (4) - user matches roles=admin, teams missing (allowed)
# - teams_only record (5) - user matches teams=dev, roles missing (allowed)
# Should NOT see:
# - different_user record (6) - user doesn't match roles=user or teams=qa
expected_ids = {"1", "2", "3", "4", "5"}
actual_ids = {record["id"] for record in result.data}
assert actual_ids == expected_ids, f"Expected to see records {expected_ids} but got {actual_ids}"
# Verify the category missing logic specifically
# Records 4 and 5 test the "category missing" branch where one attribute category is missing
category_test_ids = {record["id"] for record in result.data if record["id"] in ["4", "5"]}
assert category_test_ids == {"4", "5"}, (
f"Category missing logic failed: expected 4,5 but got {category_test_ids}"
)
finally:
# Clean up records
for record_id in ["1", "2", "3", "4", "5", "6"]:
try:
await base_sqlstore.delete(table_name, {"id": record_id})
except Exception:
pass
finally: finally:
# Clean up temporary SQLite database file if needed # Clean up records
if cleanup_path: await cleanup_records(authorized_store.sql_store, table_name, ["1", "2"])
try:
os.unlink(cleanup_path)
except OSError:
pass
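Condensed, the owner-only path these tests exercise looks roughly like the following standalone sketch against a throwaway SQLite file. It is illustrative only, reuses names that appear in the test module above, and patches the same get_authenticated_user hook.

import asyncio
import tempfile
from unittest.mock import patch

from llama_stack.distribution.access_control.datatypes import AccessRule, Action, Scope
from llama_stack.distribution.datatypes import User
from llama_stack.providers.utils.sqlstore.api import ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, sqlstore_impl


async def owner_only_demo():
    db = tempfile.NamedTemporaryFile(suffix=".db", delete=False)
    db.close()
    store = AuthorizedSqlStore(sqlstore_impl(SqliteSqlStoreConfig(db_path=db.name)))
    await store.create_table(table="demo", schema={"id": ColumnType.STRING, "data": ColumnType.STRING})

    # Only the record owner may read.
    policy = [AccessRule(permit=Scope(actions=[Action.READ]), when=["user is owner"])]

    with patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") as who:
        who.return_value = User("user1", {"roles": ["admin"]})
        await store.insert("demo", {"id": "1", "data": "user1_data"})

        who.return_value = User("user2", {"roles": ["user"]})
        rows = await store.fetch_all("demo", policy=policy)
        assert len(rows.data) == 0  # user2 owns nothing yet


asyncio.run(owner_only_demo())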

View file

@ -4,6 +4,17 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import pytest_socket
# We need to import the fixtures here so that pytest can find them # We need to import the fixtures here so that pytest can find them
# but ruff doesn't think they are used and removes the import. "noqa: F401" prevents them from being removed # but ruff doesn't think they are used and removes the import. "noqa: F401" prevents them from being removed
from .fixtures import cached_disk_dist_registry, disk_dist_registry, sqlite_kvstore # noqa: F401 from .fixtures import cached_disk_dist_registry, disk_dist_registry, sqlite_kvstore # noqa: F401
def pytest_runtest_setup(item):
"""Setup for each test - check if network access should be allowed."""
if "allow_network" in item.keywords:
pytest_socket.enable_socket()
else:
# Allowing Unix sockets is necessary for some tests that use local servers and mocks
pytest_socket.disable_socket(allow_unix_socket=True)
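In practice the hook means an unmarked test fails fast if it touches the network, while a test carrying allow_network behaves normally. A small sketch of both cases (the port number is arbitrary):

import socket

import pytest
import pytest_socket


def test_network_is_blocked_by_default():
    # No allow_network marker, so the setup hook above disabled sockets.
    with pytest.raises(pytest_socket.SocketBlockedError):
        socket.create_connection(("127.0.0.1", 11434), timeout=1)


@pytest.mark.allow_network
def test_network_is_allowed_when_marked():
    # The socket call itself is permitted here; it may still fail with a
    # normal OSError if nothing is listening on the port.
    try:
        socket.create_connection(("127.0.0.1", 11434), timeout=1).close()
    except OSError:
        pass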

View file

@ -8,8 +8,6 @@
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
import pytest
from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.common.type_system import NumberType
from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource
from llama_stack.apis.datatypes import Api from llama_stack.apis.datatypes import Api
@ -119,7 +117,6 @@ class ToolGroupsImpl(Impl):
) )
@pytest.mark.asyncio
async def test_models_routing_table(cached_disk_dist_registry): async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
@ -161,7 +158,6 @@ async def test_models_routing_table(cached_disk_dist_registry):
assert len(openai_models.data) == 0 assert len(openai_models.data) == 0
@pytest.mark.asyncio
async def test_shields_routing_table(cached_disk_dist_registry): async def test_shields_routing_table(cached_disk_dist_registry):
table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {}) table = ShieldsRoutingTable({"test_provider": SafetyImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
@ -177,7 +173,6 @@ async def test_shields_routing_table(cached_disk_dist_registry):
assert "test-shield-2" in shield_ids assert "test-shield-2" in shield_ids
@pytest.mark.asyncio
async def test_vectordbs_routing_table(cached_disk_dist_registry): async def test_vectordbs_routing_table(cached_disk_dist_registry):
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
@ -233,7 +228,6 @@ async def test_datasets_routing_table(cached_disk_dist_registry):
assert len(datasets.data) == 0 assert len(datasets.data) == 0
@pytest.mark.asyncio
async def test_scoring_functions_routing_table(cached_disk_dist_registry): async def test_scoring_functions_routing_table(cached_disk_dist_registry):
table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {}) table = ScoringFunctionsRoutingTable({"test_provider": ScoringFunctionsImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
@ -259,7 +253,6 @@ async def test_scoring_functions_routing_table(cached_disk_dist_registry):
assert "test-scoring-fn-2" in scoring_fn_ids assert "test-scoring-fn-2" in scoring_fn_ids
@pytest.mark.asyncio
async def test_benchmarks_routing_table(cached_disk_dist_registry): async def test_benchmarks_routing_table(cached_disk_dist_registry):
table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {}) table = BenchmarksRoutingTable({"test_provider": BenchmarksImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
@ -277,7 +270,6 @@ async def test_benchmarks_routing_table(cached_disk_dist_registry):
assert "test-benchmark" in benchmark_ids assert "test-benchmark" in benchmark_ids
@pytest.mark.asyncio
async def test_tool_groups_routing_table(cached_disk_dist_registry): async def test_tool_groups_routing_table(cached_disk_dist_registry):
table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {}) table = ToolGroupsRoutingTable({"test_provider": ToolGroupsImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()

View file

@ -13,7 +13,6 @@ import pytest
from llama_stack.distribution.utils.context import preserve_contexts_async_generator from llama_stack.distribution.utils.context import preserve_contexts_async_generator
@pytest.mark.asyncio
async def test_preserve_contexts_with_exception(): async def test_preserve_contexts_with_exception():
# Create context variable # Create context variable
context_var = ContextVar("exception_var", default="initial") context_var = ContextVar("exception_var", default="initial")
@ -41,7 +40,6 @@ async def test_preserve_contexts_with_exception():
context_var.reset(token) context_var.reset(token)
@pytest.mark.asyncio
async def test_preserve_contexts_empty_generator(): async def test_preserve_contexts_empty_generator():
# Create context variable # Create context variable
context_var = ContextVar("empty_var", default="initial") context_var = ContextVar("empty_var", default="initial")
@ -66,7 +64,6 @@ async def test_preserve_contexts_empty_generator():
context_var.reset(token) context_var.reset(token)
@pytest.mark.asyncio
async def test_preserve_contexts_across_event_loops(): async def test_preserve_contexts_across_event_loops():
""" """
Test that context variables are preserved across event loop boundaries with nested generators. Test that context variables are preserved across event loop boundaries with nested generators.

View file

@ -6,7 +6,6 @@
import pytest import pytest
import pytest_asyncio
from llama_stack.apis.common.responses import Order from llama_stack.apis.common.responses import Order
from llama_stack.apis.files import OpenAIFilePurpose from llama_stack.apis.files import OpenAIFilePurpose
@ -29,7 +28,7 @@ class MockUploadFile:
return self.content return self.content
@pytest_asyncio.fixture @pytest.fixture
async def files_provider(tmp_path): async def files_provider(tmp_path):
"""Create a files provider with temporary storage for testing.""" """Create a files provider with temporary storage for testing."""
storage_dir = tmp_path / "files" storage_dir = tmp_path / "files"
@ -68,7 +67,6 @@ def large_file():
class TestOpenAIFilesAPI: class TestOpenAIFilesAPI:
"""Test suite for OpenAI Files API endpoints.""" """Test suite for OpenAI Files API endpoints."""
@pytest.mark.asyncio
async def test_upload_file_success(self, files_provider, sample_text_file): async def test_upload_file_success(self, files_provider, sample_text_file):
"""Test successful file upload.""" """Test successful file upload."""
# Upload file # Upload file
@ -82,7 +80,6 @@ class TestOpenAIFilesAPI:
assert result.created_at > 0 assert result.created_at > 0
assert result.expires_at > result.created_at assert result.expires_at > result.created_at
@pytest.mark.asyncio
async def test_upload_different_purposes(self, files_provider, sample_text_file): async def test_upload_different_purposes(self, files_provider, sample_text_file):
"""Test uploading files with different purposes.""" """Test uploading files with different purposes."""
purposes = list(OpenAIFilePurpose) purposes = list(OpenAIFilePurpose)
@ -93,7 +90,6 @@ class TestOpenAIFilesAPI:
uploaded_files.append(result) uploaded_files.append(result)
assert result.purpose == purpose assert result.purpose == purpose
@pytest.mark.asyncio
async def test_upload_different_file_types(self, files_provider, sample_text_file, sample_json_file, large_file): async def test_upload_different_file_types(self, files_provider, sample_text_file, sample_json_file, large_file):
"""Test uploading different types and sizes of files.""" """Test uploading different types and sizes of files."""
files_to_test = [ files_to_test = [
@ -107,7 +103,6 @@ class TestOpenAIFilesAPI:
assert result.filename == expected_filename assert result.filename == expected_filename
assert result.bytes == len(file_obj.content) assert result.bytes == len(file_obj.content)
@pytest.mark.asyncio
async def test_list_files_empty(self, files_provider): async def test_list_files_empty(self, files_provider):
"""Test listing files when no files exist.""" """Test listing files when no files exist."""
result = await files_provider.openai_list_files() result = await files_provider.openai_list_files()
@ -117,7 +112,6 @@ class TestOpenAIFilesAPI:
assert result.first_id == "" assert result.first_id == ""
assert result.last_id == "" assert result.last_id == ""
@pytest.mark.asyncio
async def test_list_files_with_content(self, files_provider, sample_text_file, sample_json_file): async def test_list_files_with_content(self, files_provider, sample_text_file, sample_json_file):
"""Test listing files when files exist.""" """Test listing files when files exist."""
# Upload multiple files # Upload multiple files
@ -132,7 +126,6 @@ class TestOpenAIFilesAPI:
assert file1.id in file_ids assert file1.id in file_ids
assert file2.id in file_ids assert file2.id in file_ids
@pytest.mark.asyncio
async def test_list_files_with_purpose_filter(self, files_provider, sample_text_file): async def test_list_files_with_purpose_filter(self, files_provider, sample_text_file):
"""Test listing files with purpose filtering.""" """Test listing files with purpose filtering."""
# Upload file with specific purpose # Upload file with specific purpose
@ -146,7 +139,6 @@ class TestOpenAIFilesAPI:
assert result.data[0].id == uploaded_file.id assert result.data[0].id == uploaded_file.id
assert result.data[0].purpose == OpenAIFilePurpose.ASSISTANTS assert result.data[0].purpose == OpenAIFilePurpose.ASSISTANTS
@pytest.mark.asyncio
async def test_list_files_with_limit(self, files_provider, sample_text_file): async def test_list_files_with_limit(self, files_provider, sample_text_file):
"""Test listing files with limit parameter.""" """Test listing files with limit parameter."""
# Upload multiple files # Upload multiple files
@ -157,7 +149,6 @@ class TestOpenAIFilesAPI:
result = await files_provider.openai_list_files(limit=3) result = await files_provider.openai_list_files(limit=3)
assert len(result.data) == 3 assert len(result.data) == 3
@pytest.mark.asyncio
async def test_list_files_with_order(self, files_provider, sample_text_file): async def test_list_files_with_order(self, files_provider, sample_text_file):
"""Test listing files with different order.""" """Test listing files with different order."""
# Upload multiple files # Upload multiple files
@ -178,7 +169,6 @@ class TestOpenAIFilesAPI:
# Oldest should be first # Oldest should be first
assert result_asc.data[0].created_at <= result_asc.data[1].created_at <= result_asc.data[2].created_at assert result_asc.data[0].created_at <= result_asc.data[1].created_at <= result_asc.data[2].created_at
@pytest.mark.asyncio
async def test_retrieve_file_success(self, files_provider, sample_text_file): async def test_retrieve_file_success(self, files_provider, sample_text_file):
"""Test successful file retrieval.""" """Test successful file retrieval."""
# Upload file # Upload file
@ -197,13 +187,11 @@ class TestOpenAIFilesAPI:
assert retrieved_file.created_at == uploaded_file.created_at assert retrieved_file.created_at == uploaded_file.created_at
assert retrieved_file.expires_at == uploaded_file.expires_at assert retrieved_file.expires_at == uploaded_file.expires_at
@pytest.mark.asyncio
async def test_retrieve_file_not_found(self, files_provider): async def test_retrieve_file_not_found(self, files_provider):
"""Test retrieving a non-existent file.""" """Test retrieving a non-existent file."""
with pytest.raises(ValueError, match="File with id file-nonexistent not found"): with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
await files_provider.openai_retrieve_file("file-nonexistent") await files_provider.openai_retrieve_file("file-nonexistent")
@pytest.mark.asyncio
async def test_retrieve_file_content_success(self, files_provider, sample_text_file): async def test_retrieve_file_content_success(self, files_provider, sample_text_file):
"""Test successful file content retrieval.""" """Test successful file content retrieval."""
# Upload file # Upload file
@ -217,13 +205,11 @@ class TestOpenAIFilesAPI:
# Verify content # Verify content
assert content.body == sample_text_file.content assert content.body == sample_text_file.content
@pytest.mark.asyncio
async def test_retrieve_file_content_not_found(self, files_provider): async def test_retrieve_file_content_not_found(self, files_provider):
"""Test retrieving content of a non-existent file.""" """Test retrieving content of a non-existent file."""
with pytest.raises(ValueError, match="File with id file-nonexistent not found"): with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
await files_provider.openai_retrieve_file_content("file-nonexistent") await files_provider.openai_retrieve_file_content("file-nonexistent")
@pytest.mark.asyncio
async def test_delete_file_success(self, files_provider, sample_text_file): async def test_delete_file_success(self, files_provider, sample_text_file):
"""Test successful file deletion.""" """Test successful file deletion."""
# Upload file # Upload file
@ -245,13 +231,11 @@ class TestOpenAIFilesAPI:
with pytest.raises(ValueError, match=f"File with id {uploaded_file.id} not found"): with pytest.raises(ValueError, match=f"File with id {uploaded_file.id} not found"):
await files_provider.openai_retrieve_file(uploaded_file.id) await files_provider.openai_retrieve_file(uploaded_file.id)
@pytest.mark.asyncio
async def test_delete_file_not_found(self, files_provider): async def test_delete_file_not_found(self, files_provider):
"""Test deleting a non-existent file.""" """Test deleting a non-existent file."""
with pytest.raises(ValueError, match="File with id file-nonexistent not found"): with pytest.raises(ValueError, match="File with id file-nonexistent not found"):
await files_provider.openai_delete_file("file-nonexistent") await files_provider.openai_delete_file("file-nonexistent")
@pytest.mark.asyncio
async def test_file_persistence_across_operations(self, files_provider, sample_text_file): async def test_file_persistence_across_operations(self, files_provider, sample_text_file):
"""Test that files persist correctly across multiple operations.""" """Test that files persist correctly across multiple operations."""
# Upload file # Upload file
@ -279,7 +263,6 @@ class TestOpenAIFilesAPI:
files_list = await files_provider.openai_list_files() files_list = await files_provider.openai_list_files()
assert len(files_list.data) == 0 assert len(files_list.data) == 0
@pytest.mark.asyncio
async def test_multiple_files_operations(self, files_provider, sample_text_file, sample_json_file): async def test_multiple_files_operations(self, files_provider, sample_text_file, sample_json_file):
"""Test operations with multiple files.""" """Test operations with multiple files."""
# Upload multiple files # Upload multiple files
@ -302,7 +285,6 @@ class TestOpenAIFilesAPI:
content = await files_provider.openai_retrieve_file_content(file2.id) content = await files_provider.openai_retrieve_file_content(file2.id)
assert content.body == sample_json_file.content assert content.body == sample_json_file.content
@pytest.mark.asyncio
async def test_file_id_uniqueness(self, files_provider, sample_text_file): async def test_file_id_uniqueness(self, files_provider, sample_text_file):
"""Test that each uploaded file gets a unique ID.""" """Test that each uploaded file gets a unique ID."""
file_ids = set() file_ids = set()
@ -316,7 +298,6 @@ class TestOpenAIFilesAPI:
file_ids.add(uploaded_file.id) file_ids.add(uploaded_file.id)
assert uploaded_file.id.startswith("file-") assert uploaded_file.id.startswith("file-")
@pytest.mark.asyncio
async def test_file_no_filename_handling(self, files_provider): async def test_file_no_filename_handling(self, files_provider):
"""Test handling files with no filename.""" """Test handling files with no filename."""
file_without_name = MockUploadFile(b"content", None) # No filename file_without_name = MockUploadFile(b"content", None) # No filename
@ -327,7 +308,6 @@ class TestOpenAIFilesAPI:
assert uploaded_file.filename == "uploaded_file" # Default filename assert uploaded_file.filename == "uploaded_file" # Default filename
@pytest.mark.asyncio
async def test_after_pagination_works(self, files_provider, sample_text_file): async def test_after_pagination_works(self, files_provider, sample_text_file):
"""Test that 'after' pagination works correctly.""" """Test that 'after' pagination works correctly."""
# Upload multiple files to test pagination # Upload multiple files to test pagination

View file

@ -4,14 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import pytest_asyncio import pytest
from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry from llama_stack.distribution.store.registry import CachedDiskDistributionRegistry, DiskDistributionRegistry
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl from llama_stack.providers.utils.kvstore.sqlite import SqliteKVStoreImpl
@pytest_asyncio.fixture(scope="function") @pytest.fixture(scope="function")
async def sqlite_kvstore(tmp_path): async def sqlite_kvstore(tmp_path):
db_path = tmp_path / "test_kv.db" db_path = tmp_path / "test_kv.db"
kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix()) kvstore_config = SqliteKVStoreConfig(db_path=db_path.as_posix())
@ -20,14 +20,14 @@ async def sqlite_kvstore(tmp_path):
yield kvstore yield kvstore
@pytest_asyncio.fixture(scope="function") @pytest.fixture(scope="function")
async def disk_dist_registry(sqlite_kvstore): async def disk_dist_registry(sqlite_kvstore):
registry = DiskDistributionRegistry(sqlite_kvstore) registry = DiskDistributionRegistry(sqlite_kvstore)
await registry.initialize() await registry.initialize()
yield registry yield registry
@pytest_asyncio.fixture(scope="function") @pytest.fixture(scope="function")
async def cached_disk_dist_registry(sqlite_kvstore): async def cached_disk_dist_registry(sqlite_kvstore):
registry = CachedDiskDistributionRegistry(sqlite_kvstore) registry = CachedDiskDistributionRegistry(sqlite_kvstore)
await registry.initialize() await registry.initialize()
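This is the pattern the change standardizes on: with pytest-asyncio running in auto mode, async generator fixtures can be declared with the plain pytest.fixture decorator, so pytest_asyncio.fixture is no longer needed. A toy sketch:

import pytest


@pytest.fixture
async def widget():
    resource = {"ready": True}  # stand-in for an awaited setup step
    yield resource
    resource.clear()            # teardown still runs after the test


async def test_widget_is_ready(widget):
    assert widget["ready"]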

View file

@ -8,7 +8,6 @@ from datetime import datetime
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
import pytest import pytest
import pytest_asyncio
from llama_stack.apis.agents import ( from llama_stack.apis.agents import (
Agent, Agent,
@ -50,7 +49,7 @@ def config(tmp_path):
) )
@pytest_asyncio.fixture @pytest.fixture
async def agents_impl(config, mock_apis): async def agents_impl(config, mock_apis):
impl = MetaReferenceAgentsImpl( impl = MetaReferenceAgentsImpl(
config, config,
@ -117,7 +116,6 @@ def sample_agent_config():
) )
@pytest.mark.asyncio
async def test_create_agent(agents_impl, sample_agent_config): async def test_create_agent(agents_impl, sample_agent_config):
response = await agents_impl.create_agent(sample_agent_config) response = await agents_impl.create_agent(sample_agent_config)
@ -132,7 +130,6 @@ async def test_create_agent(agents_impl, sample_agent_config):
assert isinstance(agent_info.created_at, datetime) assert isinstance(agent_info.created_at, datetime)
@pytest.mark.asyncio
async def test_get_agent(agents_impl, sample_agent_config): async def test_get_agent(agents_impl, sample_agent_config):
create_response = await agents_impl.create_agent(sample_agent_config) create_response = await agents_impl.create_agent(sample_agent_config)
agent_id = create_response.agent_id agent_id = create_response.agent_id
@ -146,7 +143,6 @@ async def test_get_agent(agents_impl, sample_agent_config):
assert isinstance(agent.created_at, datetime) assert isinstance(agent.created_at, datetime)
@pytest.mark.asyncio
async def test_list_agents(agents_impl, sample_agent_config): async def test_list_agents(agents_impl, sample_agent_config):
agent1_response = await agents_impl.create_agent(sample_agent_config) agent1_response = await agents_impl.create_agent(sample_agent_config)
agent2_response = await agents_impl.create_agent(sample_agent_config) agent2_response = await agents_impl.create_agent(sample_agent_config)
@ -160,7 +156,6 @@ async def test_list_agents(agents_impl, sample_agent_config):
assert agent2_response.agent_id in agent_ids assert agent2_response.agent_id in agent_ids
@pytest.mark.asyncio
@pytest.mark.parametrize("enable_session_persistence", [True, False]) @pytest.mark.parametrize("enable_session_persistence", [True, False])
async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence): async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence):
# Create an agent with specified persistence setting # Create an agent with specified persistence setting
@ -188,7 +183,6 @@ async def test_create_agent_session_persistence(agents_impl, sample_agent_config
await agents_impl.get_agents_session(agent_id, session_response.session_id) await agents_impl.get_agents_session(agent_id, session_response.session_id)
@pytest.mark.asyncio
@pytest.mark.parametrize("enable_session_persistence", [True, False]) @pytest.mark.parametrize("enable_session_persistence", [True, False])
async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence): async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence):
# Create an agent with specified persistence setting # Create an agent with specified persistence setting
@ -221,7 +215,6 @@ async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config,
assert session2.session_id in {s["session_id"] for s in sessions.data} assert session2.session_id in {s["session_id"] for s in sessions.data}
@pytest.mark.asyncio
async def test_delete_agent(agents_impl, sample_agent_config): async def test_delete_agent(agents_impl, sample_agent_config):
# Create an agent # Create an agent
response = await agents_impl.create_agent(sample_agent_config) response = await agents_impl.create_agent(sample_agent_config)

View file

@ -122,7 +122,6 @@ async def fake_stream(fixture: str = "simple_chat_completion.yaml"):
) )
@pytest.mark.asyncio
async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api): async def test_create_openai_response_with_string_input(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with a simple string input.""" """Test creating an OpenAI response with a simple string input."""
# Setup # Setup
@ -155,7 +154,6 @@ async def test_create_openai_response_with_string_input(openai_responses_impl, m
assert result.output[0].content[0].text == "Dublin" assert result.output[0].content[0].text == "Dublin"
@pytest.mark.asyncio
async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api): async def test_create_openai_response_with_string_input_with_tools(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with a simple string input and tools.""" """Test creating an OpenAI response with a simple string input and tools."""
# Setup # Setup
@ -224,7 +222,6 @@ async def test_create_openai_response_with_string_input_with_tools(openai_respon
assert result.output[1].content[0].annotations == [] assert result.output[1].content[0].annotations == []
@pytest.mark.asyncio
async def test_create_openai_response_with_tool_call_type_none(openai_responses_impl, mock_inference_api): async def test_create_openai_response_with_tool_call_type_none(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with a tool call response that has a type of None.""" """Test creating an OpenAI response with a tool call response that has a type of None."""
# Setup # Setup
@ -294,7 +291,6 @@ async def test_create_openai_response_with_tool_call_type_none(openai_responses_
assert chunks[1].response.output[0].name == "get_weather" assert chunks[1].response.output[0].name == "get_weather"
@pytest.mark.asyncio
async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api): async def test_create_openai_response_with_multiple_messages(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with multiple messages.""" """Test creating an OpenAI response with multiple messages."""
# Setup # Setup
@ -340,7 +336,6 @@ async def test_create_openai_response_with_multiple_messages(openai_responses_im
assert isinstance(inference_messages[i], OpenAIDeveloperMessageParam) assert isinstance(inference_messages[i], OpenAIDeveloperMessageParam)
@pytest.mark.asyncio
async def test_prepend_previous_response_none(openai_responses_impl): async def test_prepend_previous_response_none(openai_responses_impl):
"""Test prepending no previous response to a new response.""" """Test prepending no previous response to a new response."""
@ -348,7 +343,6 @@ async def test_prepend_previous_response_none(openai_responses_impl):
assert input == "fake_input" assert input == "fake_input"
@pytest.mark.asyncio
async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store): async def test_prepend_previous_response_basic(openai_responses_impl, mock_responses_store):
"""Test prepending a basic previous response to a new response.""" """Test prepending a basic previous response to a new response."""
@ -388,7 +382,6 @@ async def test_prepend_previous_response_basic(openai_responses_impl, mock_respo
assert input[2].content == "fake_input" assert input[2].content == "fake_input"
@pytest.mark.asyncio
async def test_prepend_previous_response_web_search(openai_responses_impl, mock_responses_store): async def test_prepend_previous_response_web_search(openai_responses_impl, mock_responses_store):
"""Test prepending a web search previous response to a new response.""" """Test prepending a web search previous response to a new response."""
input_item_message = OpenAIResponseMessage( input_item_message = OpenAIResponseMessage(
@ -434,7 +427,6 @@ async def test_prepend_previous_response_web_search(openai_responses_impl, mock_
assert input[3].content == "fake_input" assert input[3].content == "fake_input"
@pytest.mark.asyncio
async def test_create_openai_response_with_instructions(openai_responses_impl, mock_inference_api): async def test_create_openai_response_with_instructions(openai_responses_impl, mock_inference_api):
# Setup # Setup
input_text = "What is the capital of Ireland?" input_text = "What is the capital of Ireland?"
@ -463,7 +455,6 @@ async def test_create_openai_response_with_instructions(openai_responses_impl, m
assert sent_messages[1].content == input_text assert sent_messages[1].content == input_text
@pytest.mark.asyncio
async def test_create_openai_response_with_instructions_and_multiple_messages( async def test_create_openai_response_with_instructions_and_multiple_messages(
openai_responses_impl, mock_inference_api openai_responses_impl, mock_inference_api
): ):
@ -508,7 +499,6 @@ async def test_create_openai_response_with_instructions_and_multiple_messages(
assert sent_messages[3].content == "Which is the largest?" assert sent_messages[3].content == "Which is the largest?"
@pytest.mark.asyncio
async def test_create_openai_response_with_instructions_and_previous_response( async def test_create_openai_response_with_instructions_and_previous_response(
openai_responses_impl, mock_responses_store, mock_inference_api openai_responses_impl, mock_responses_store, mock_inference_api
): ):
@ -565,7 +555,6 @@ async def test_create_openai_response_with_instructions_and_previous_response(
assert sent_messages[3].content == "Which is the largest?" assert sent_messages[3].content == "Which is the largest?"
@pytest.mark.asyncio
async def test_list_openai_response_input_items_delegation(openai_responses_impl, mock_responses_store): async def test_list_openai_response_input_items_delegation(openai_responses_impl, mock_responses_store):
"""Test that list_openai_response_input_items properly delegates to responses_store with correct parameters.""" """Test that list_openai_response_input_items properly delegates to responses_store with correct parameters."""
# Setup # Setup
@ -601,7 +590,6 @@ async def test_list_openai_response_input_items_delegation(openai_responses_impl
assert result.data[0].id == "msg_123" assert result.data[0].id == "msg_123"
@pytest.mark.asyncio
async def test_responses_store_list_input_items_logic(): async def test_responses_store_list_input_items_logic():
"""Test ResponsesStore list_response_input_items logic - mocks get_response_object to test actual ordering/limiting.""" """Test ResponsesStore list_response_input_items logic - mocks get_response_object to test actual ordering/limiting."""
@ -680,7 +668,6 @@ async def test_responses_store_list_input_items_logic():
assert len(result.data) == 0 # Should return no items assert len(result.data) == 0 # Should return no items
@pytest.mark.asyncio
async def test_store_response_uses_rehydrated_input_with_previous_response( async def test_store_response_uses_rehydrated_input_with_previous_response(
openai_responses_impl, mock_responses_store, mock_inference_api openai_responses_impl, mock_responses_store, mock_inference_api
): ):
@ -747,7 +734,6 @@ async def test_store_response_uses_rehydrated_input_with_previous_response(
assert result.status == "completed" assert result.status == "completed"
@pytest.mark.asyncio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"text_format, response_format", "text_format, response_format",
[ [
@ -787,7 +773,6 @@ async def test_create_openai_response_with_text_format(
assert first_call.kwargs["response_format"] == response_format assert first_call.kwargs["response_format"] == response_format
@pytest.mark.asyncio
async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api): async def test_create_openai_response_with_invalid_text_format(openai_responses_impl, mock_inference_api):
"""Test creating an OpenAI response with an invalid text format.""" """Test creating an OpenAI response with an invalid text format."""
# Setup # Setup

View file

@ -9,7 +9,6 @@ from datetime import datetime
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
import pytest_asyncio
from llama_stack.apis.agents import Turn from llama_stack.apis.agents import Turn
from llama_stack.apis.inference import CompletionMessage, StopReason from llama_stack.apis.inference import CompletionMessage, StopReason
@ -17,13 +16,12 @@ from llama_stack.distribution.datatypes import User
from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo from llama_stack.providers.inline.agents.meta_reference.persistence import AgentPersistence, AgentSessionInfo
@pytest_asyncio.fixture @pytest.fixture
async def test_setup(sqlite_kvstore): async def test_setup(sqlite_kvstore):
agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={}) agent_persistence = AgentPersistence(agent_id="test_agent", kvstore=sqlite_kvstore, policy={})
yield agent_persistence yield agent_persistence
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user") @patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup): async def test_session_creation_with_access_attributes(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup agent_persistence = test_setup
@ -44,7 +42,6 @@ async def test_session_creation_with_access_attributes(mock_get_authenticated_us
assert session_info.owner.attributes["teams"] == ["ai-team"] assert session_info.owner.attributes["teams"] == ["ai-team"]
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user") @patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_session_access_control(mock_get_authenticated_user, test_setup): async def test_session_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup agent_persistence = test_setup
@ -79,7 +76,6 @@ async def test_session_access_control(mock_get_authenticated_user, test_setup):
assert retrieved_session is None assert retrieved_session is None
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user") @patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_turn_access_control(mock_get_authenticated_user, test_setup): async def test_turn_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup agent_persistence = test_setup
@ -133,7 +129,6 @@ async def test_turn_access_control(mock_get_authenticated_user, test_setup):
await agent_persistence.get_session_turns(session_id) await agent_persistence.get_session_turns(session_id)
@pytest.mark.asyncio
@patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user") @patch("llama_stack.providers.inline.agents.meta_reference.persistence.get_authenticated_user")
async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup): async def test_tool_call_and_infer_iters_access_control(mock_get_authenticated_user, test_setup):
agent_persistence = test_setup agent_persistence = test_setup

View file

@ -0,0 +1,73 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from unittest.mock import MagicMock
from llama_stack.distribution.request_headers import request_provider_data_context
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter
from llama_stack.providers.remote.inference.openai.config import OpenAIConfig
from llama_stack.providers.remote.inference.openai.openai import OpenAIInferenceAdapter
from llama_stack.providers.remote.inference.together.config import TogetherImplConfig
from llama_stack.providers.remote.inference.together.together import TogetherInferenceAdapter
def test_groq_provider_openai_client_caching():
"""Ensure the Groq provider does not cache api keys across client requests"""
config = GroqConfig()
inference_adapter = GroqInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
openai_client = inference_adapter._get_openai_client()
assert openai_client.api_key == api_key
def test_openai_provider_openai_client_caching():
"""Ensure the OpenAI provider does not cache api keys across client requests"""
config = OpenAIConfig()
inference_adapter = OpenAIInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context(
{"x-llamastack-provider-data": json.dumps({inference_adapter.provider_data_api_key_field: api_key})}
):
openai_client = inference_adapter._get_openai_client()
assert openai_client.api_key == api_key
def test_together_provider_openai_client_caching():
"""Ensure the Together provider does not cache api keys across client requests"""
config = TogetherImplConfig()
inference_adapter = TogetherInferenceAdapter(config)
inference_adapter.__provider_spec__ = MagicMock()
inference_adapter.__provider_spec__.provider_data_validator = (
"llama_stack.providers.remote.inference.together.TogetherProviderDataValidator"
)
for api_key in ["test1", "test2"]:
with request_provider_data_context({"x-llamastack-provider-data": json.dumps({"together_api_key": api_key})}):
together_client = inference_adapter._get_client()
assert together_client.client.api_key == api_key
openai_client = inference_adapter._get_openai_client()
assert openai_client.api_key == api_key
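The client-side counterpart to these tests is that per-request credentials ride in the x-llamastack-provider-data header instead of being cached on the adapter. A hedged sketch of what a caller might send; together_api_key is the literal field used in the test above, other providers use the field named by their own provider_data_api_key_field.

import json

# Two successive requests can carry two different keys; the adapters above
# re-read the key from the request context instead of reusing a cached client.
headers_first = {"x-llamastack-provider-data": json.dumps({"together_api_key": "key-one"})}
headers_second = {"x-llamastack-provider-data": json.dumps({"together_api_key": "key-two"})}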

View file

@ -14,7 +14,6 @@ from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import pytest_asyncio
from openai.types.chat.chat_completion_chunk import ( from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk, ChatCompletionChunk as OpenAIChatCompletionChunk,
) )
@ -103,7 +102,7 @@ def mock_openai_models_list():
yield mock_list yield mock_list
@pytest_asyncio.fixture(scope="module") @pytest.fixture(scope="module")
async def vllm_inference_adapter(): async def vllm_inference_adapter():
config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345") config = VLLMInferenceAdapterConfig(url="http://mocked.localhost:12345")
inference_adapter = VLLMInferenceAdapter(config) inference_adapter = VLLMInferenceAdapter(config)
@ -112,7 +111,6 @@ async def vllm_inference_adapter():
return inference_adapter return inference_adapter
@pytest.mark.asyncio
async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter): async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inference_adapter):
async def mock_openai_models(): async def mock_openai_models():
yield OpenAIModel(id="foo", created=1, object="model", owned_by="test") yield OpenAIModel(id="foo", created=1, object="model", owned_by="test")
@ -125,7 +123,6 @@ async def test_register_model_checks_vllm(mock_openai_models_list, vllm_inferenc
mock_openai_models_list.assert_called() mock_openai_models_list.assert_called()
@pytest.mark.asyncio
async def test_old_vllm_tool_choice(vllm_inference_adapter): async def test_old_vllm_tool_choice(vllm_inference_adapter):
""" """
Test that we set tool_choice to none when no tools are in use Test that we set tool_choice to none when no tools are in use
@ -149,7 +146,6 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
assert request.tool_config.tool_choice == ToolChoice.none assert request.tool_config.tool_choice == ToolChoice.none
@pytest.mark.asyncio
async def test_tool_call_response(vllm_inference_adapter): async def test_tool_call_response(vllm_inference_adapter):
"""Verify that tool call arguments from a CompletionMessage are correctly converted """Verify that tool call arguments from a CompletionMessage are correctly converted
into the expected JSON format.""" into the expected JSON format."""
@ -192,7 +188,6 @@ async def test_tool_call_response(vllm_inference_adapter):
] ]
@pytest.mark.asyncio
async def test_tool_call_delta_empty_tool_call_buf(): async def test_tool_call_delta_empty_tool_call_buf():
""" """
Test that we don't generate extra chunks when processing a Test that we don't generate extra chunks when processing a
@ -222,7 +217,6 @@ async def test_tool_call_delta_empty_tool_call_buf():
assert chunks[1].event.stop_reason == StopReason.end_of_turn assert chunks[1].event.stop_reason == StopReason.end_of_turn
@pytest.mark.asyncio
async def test_tool_call_delta_streaming_arguments_dict(): async def test_tool_call_delta_streaming_arguments_dict():
async def mock_stream(): async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk( mock_chunk_1 = OpenAIChatCompletionChunk(
@ -297,7 +291,6 @@ async def test_tool_call_delta_streaming_arguments_dict():
assert chunks[2].event.event_type.value == "complete" assert chunks[2].event.event_type.value == "complete"
@pytest.mark.asyncio
async def test_multiple_tool_calls(): async def test_multiple_tool_calls():
async def mock_stream(): async def mock_stream():
mock_chunk_1 = OpenAIChatCompletionChunk( mock_chunk_1 = OpenAIChatCompletionChunk(
@ -376,7 +369,6 @@ async def test_multiple_tool_calls():
assert chunks[3].event.event_type.value == "complete" assert chunks[3].event.event_type.value == "complete"
@pytest.mark.asyncio
async def test_process_vllm_chat_completion_stream_response_no_choices(): async def test_process_vllm_chat_completion_stream_response_no_choices():
""" """
Test that we don't error out when vLLM returns no choices for a Test that we don't error out when vLLM returns no choices for a
@ -401,6 +393,7 @@ async def test_process_vllm_chat_completion_stream_response_no_choices():
assert chunks[0].event.event_type.value == "start" assert chunks[0].event.event_type.value == "start"
@pytest.mark.allow_network
def test_chat_completion_doesnt_block_event_loop(caplog): def test_chat_completion_doesnt_block_event_loop(caplog):
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()
loop.set_debug(True) loop.set_debug(True)
@ -453,7 +446,6 @@ def test_chat_completion_doesnt_block_event_loop(caplog):
assert not asyncio_warnings assert not asyncio_warnings
@pytest.mark.asyncio
async def test_get_params_empty_tools(vllm_inference_adapter): async def test_get_params_empty_tools(vllm_inference_adapter):
request = ChatCompletionRequest( request = ChatCompletionRequest(
tools=[], tools=[],
@ -464,7 +456,6 @@ async def test_get_params_empty_tools(vllm_inference_adapter):
assert "tools" not in params assert "tools" not in params
@pytest.mark.asyncio
async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_chunk(): async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_chunk():
""" """
Tests the edge case where the model returns the arguments for the tool call in the same chunk that Tests the edge case where the model returns the arguments for the tool call in the same chunk that
@ -543,7 +534,6 @@ async def test_process_vllm_chat_completion_stream_response_tool_call_args_last_
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
@pytest.mark.asyncio
async def test_process_vllm_chat_completion_stream_response_no_finish_reason(): async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
""" """
Tests the edge case where the model requests a tool call and stays idle without explicitly providing the Tests the edge case where the model requests a tool call and stays idle without explicitly providing the
@ -596,7 +586,6 @@ async def test_process_vllm_chat_completion_stream_response_no_finish_reason():
assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments assert chunks[-2].event.delta.tool_call.arguments == mock_tool_arguments
@pytest.mark.asyncio
async def test_process_vllm_chat_completion_stream_response_tool_without_args(): async def test_process_vllm_chat_completion_stream_response_tool_without_args():
""" """
Tests the edge case where no arguments are provided for the tool call. Tests the edge case where no arguments are provided for the tool call.
@ -645,7 +634,6 @@ async def test_process_vllm_chat_completion_stream_response_tool_without_args():
assert chunks[-2].event.delta.tool_call.arguments == {} assert chunks[-2].event.delta.tool_call.arguments == {}
@pytest.mark.asyncio
async def test_health_status_success(vllm_inference_adapter): async def test_health_status_success(vllm_inference_adapter):
""" """
Test the health method of VLLM InferenceAdapter when the connection is successful. Test the health method of VLLM InferenceAdapter when the connection is successful.
@ -679,7 +667,6 @@ async def test_health_status_success(vllm_inference_adapter):
mock_models.list.assert_called_once() mock_models.list.assert_called_once()
@pytest.mark.asyncio
async def test_health_status_failure(vllm_inference_adapter): async def test_health_status_failure(vllm_inference_adapter):
""" """
Test the health method of VLLM InferenceAdapter when the connection fails. Test the health method of VLLM InferenceAdapter when the connection fails.
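Beyond dropping the @pytest.mark.asyncio decorators (the run instructions later in this commit pass --asyncio-mode=auto, which makes them redundant), this file gains a @pytest.mark.allow_network marker. Custom markers are usually declared so pytest does not warn about them and so --strict-markers accepts them; a minimal sketch of one way to declare it in a conftest.py, assuming the project does not already register the marker elsewhere:

def pytest_configure(config):
    # Register the custom marker used by test_chat_completion_doesnt_block_event_loop.
    config.addinivalue_line(
        "markers",
        "allow_network: test is allowed to make real network connections",
    )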

View file

@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import pytest
from llama_stack.apis.common.content_types import TextContentItem from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
@ -23,7 +22,6 @@ from llama_stack.providers.utils.inference.openai_compat import (
) )
@pytest.mark.asyncio
async def test_convert_message_to_openai_dict(): async def test_convert_message_to_openai_dict():
message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user") message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
assert await convert_message_to_openai_dict(message) == { assert await convert_message_to_openai_dict(message) == {
@ -33,7 +31,6 @@ async def test_convert_message_to_openai_dict():
# Test convert_message_to_openai_dict with a tool call # Test convert_message_to_openai_dict with a tool call
@pytest.mark.asyncio
async def test_convert_message_to_openai_dict_with_tool_call(): async def test_convert_message_to_openai_dict_with_tool_call():
message = CompletionMessage( message = CompletionMessage(
content="", content="",
@ -54,7 +51,6 @@ async def test_convert_message_to_openai_dict_with_tool_call():
} }
@pytest.mark.asyncio
async def test_convert_message_to_openai_dict_with_builtin_tool_call(): async def test_convert_message_to_openai_dict_with_builtin_tool_call():
message = CompletionMessage( message = CompletionMessage(
content="", content="",
@ -80,7 +76,6 @@ async def test_convert_message_to_openai_dict_with_builtin_tool_call():
} }
@pytest.mark.asyncio
async def test_openai_messages_to_messages_with_content_str(): async def test_openai_messages_to_messages_with_content_str():
openai_messages = [ openai_messages = [
OpenAISystemMessageParam(content="system message"), OpenAISystemMessageParam(content="system message"),
@ -98,7 +93,6 @@ async def test_openai_messages_to_messages_with_content_str():
assert llama_messages[2].content == "assistant message" assert llama_messages[2].content == "assistant message"
@pytest.mark.asyncio
async def test_openai_messages_to_messages_with_content_list(): async def test_openai_messages_to_messages_with_content_list():
openai_messages = [ openai_messages = [
OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]), OpenAISystemMessageParam(content=[OpenAIChatCompletionContentPartTextParam(text="system message")]),

View file

@ -13,7 +13,6 @@ from llama_stack.apis.tools import RAGDocument
from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, content_from_doc from llama_stack.providers.utils.memory.vector_store import content_from_data_and_mime_type, content_from_doc
@pytest.mark.asyncio
async def test_content_from_doc_with_url(): async def test_content_from_doc_with_url():
"""Test extracting content from RAGDocument with URL content.""" """Test extracting content from RAGDocument with URL content."""
mock_url = URL(uri="https://example.com") mock_url = URL(uri="https://example.com")
@ -33,7 +32,6 @@ async def test_content_from_doc_with_url():
mock_instance.get.assert_called_once_with(mock_url.uri) mock_instance.get.assert_called_once_with(mock_url.uri)
@pytest.mark.asyncio
async def test_content_from_doc_with_pdf_url(): async def test_content_from_doc_with_pdf_url():
"""Test extracting content from RAGDocument with URL pointing to a PDF.""" """Test extracting content from RAGDocument with URL pointing to a PDF."""
mock_url = URL(uri="https://example.com/document.pdf") mock_url = URL(uri="https://example.com/document.pdf")
@ -58,7 +56,6 @@ async def test_content_from_doc_with_pdf_url():
mock_parse_pdf.assert_called_once_with(b"PDF binary data") mock_parse_pdf.assert_called_once_with(b"PDF binary data")
@pytest.mark.asyncio
async def test_content_from_doc_with_data_url(): async def test_content_from_doc_with_data_url():
"""Test extracting content from RAGDocument with data URL content.""" """Test extracting content from RAGDocument with data URL content."""
data_url = "data:text/plain;base64,SGVsbG8gV29ybGQ=" # "Hello World" base64 encoded data_url = "data:text/plain;base64,SGVsbG8gV29ybGQ=" # "Hello World" base64 encoded
@ -74,7 +71,6 @@ async def test_content_from_doc_with_data_url():
mock_content_from_data.assert_called_once_with(data_url) mock_content_from_data.assert_called_once_with(data_url)
@pytest.mark.asyncio
async def test_content_from_doc_with_string(): async def test_content_from_doc_with_string():
"""Test extracting content from RAGDocument with string content.""" """Test extracting content from RAGDocument with string content."""
content_string = "This is plain text content" content_string = "This is plain text content"
@ -85,7 +81,6 @@ async def test_content_from_doc_with_string():
assert result == content_string assert result == content_string
@pytest.mark.asyncio
async def test_content_from_doc_with_string_url(): async def test_content_from_doc_with_string_url():
"""Test extracting content from RAGDocument with string URL content.""" """Test extracting content from RAGDocument with string URL content."""
url_string = "https://example.com" url_string = "https://example.com"
@ -105,7 +100,6 @@ async def test_content_from_doc_with_string_url():
mock_instance.get.assert_called_once_with(url_string) mock_instance.get.assert_called_once_with(url_string)
@pytest.mark.asyncio
async def test_content_from_doc_with_string_pdf_url(): async def test_content_from_doc_with_string_pdf_url():
"""Test extracting content from RAGDocument with string URL pointing to a PDF.""" """Test extracting content from RAGDocument with string URL pointing to a PDF."""
url_string = "https://example.com/document.pdf" url_string = "https://example.com/document.pdf"
@ -130,7 +124,6 @@ async def test_content_from_doc_with_string_pdf_url():
mock_parse_pdf.assert_called_once_with(b"PDF binary data") mock_parse_pdf.assert_called_once_with(b"PDF binary data")
@pytest.mark.asyncio
async def test_content_from_doc_with_interleaved_content(): async def test_content_from_doc_with_interleaved_content():
"""Test extracting content from RAGDocument with InterleavedContent (the new case added in the commit).""" """Test extracting content from RAGDocument with InterleavedContent (the new case added in the commit)."""
interleaved_content = [TextContentItem(text="First item"), TextContentItem(text="Second item")] interleaved_content = [TextContentItem(text="First item"), TextContentItem(text="Second item")]

View file

@ -94,8 +94,8 @@ class MockModelRegistryHelperWithDynamicModels(ModelRegistryHelper):
super().__init__(model_entries) super().__init__(model_entries)
self._available_models = available_models self._available_models = available_models
async def query_available_models(self) -> list[str]: async def check_model_availability(self, model: str) -> bool:
return self._available_models return model in self._available_models
@pytest.fixture @pytest.fixture
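The MockModelRegistryHelperWithDynamicModels override above captures the new contract: instead of returning a full model list from query_available_models, a provider answers a per-model yes/no question through check_model_availability(model) -> bool. A real adapter that serves models dynamically would follow the same shape; a minimal sketch, assuming the same ModelRegistryHelper import as the test module and a hypothetical fetch_served_model_ids() standing in for the provider's backend call:

class DynamicModelsAdapter(ModelRegistryHelper):
    async def check_model_availability(self, model: str) -> bool:
        # fetch_served_model_ids() is hypothetical: a real adapter would ask its
        # serving backend which model ids are currently available.
        served_ids = await self.fetch_served_model_ids()
        return model in served_ids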
@ -118,18 +118,15 @@ def helper_with_dynamic_models(
) )
@pytest.mark.asyncio
async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None: async def test_lookup_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
assert helper.get_provider_model_id(unknown_model.model_id) is None assert helper.get_provider_model_id(unknown_model.model_id) is None
@pytest.mark.asyncio
async def test_register_unknown_provider_model(helper: ModelRegistryHelper, unknown_model: Model) -> None: async def test_register_unknown_provider_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
await helper.register_model(unknown_model) await helper.register_model(unknown_model)
@pytest.mark.asyncio
async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -> None: async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -> None:
model = Model( model = Model(
provider_id=known_model.provider_id, provider_id=known_model.provider_id,
@ -141,7 +138,6 @@ async def test_register_model(helper: ModelRegistryHelper, known_model: Model) -
assert helper.get_provider_model_id(model.model_id) == model.provider_resource_id assert helper.get_provider_model_id(model.model_id) == model.provider_resource_id
@pytest.mark.asyncio
async def test_register_model_from_alias(helper: ModelRegistryHelper, known_model: Model) -> None: async def test_register_model_from_alias(helper: ModelRegistryHelper, known_model: Model) -> None:
model = Model( model = Model(
provider_id=known_model.provider_id, provider_id=known_model.provider_id,
@ -153,13 +149,11 @@ async def test_register_model_from_alias(helper: ModelRegistryHelper, known_mode
assert helper.get_provider_model_id(model.model_id) == known_model.provider_resource_id assert helper.get_provider_model_id(model.model_id) == known_model.provider_resource_id
@pytest.mark.asyncio
async def test_register_model_existing(helper: ModelRegistryHelper, known_model: Model) -> None: async def test_register_model_existing(helper: ModelRegistryHelper, known_model: Model) -> None:
await helper.register_model(known_model) await helper.register_model(known_model)
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_resource_id
@pytest.mark.asyncio
async def test_register_model_existing_different( async def test_register_model_existing_different(
helper: ModelRegistryHelper, known_model: Model, known_model2: Model helper: ModelRegistryHelper, known_model: Model, known_model2: Model
) -> None: ) -> None:
@ -168,7 +162,6 @@ async def test_register_model_existing_different(
await helper.register_model(known_model) await helper.register_model(known_model)
@pytest.mark.asyncio
async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None: async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model) -> None:
await helper.register_model(known_model) # duplicate entry await helper.register_model(known_model) # duplicate entry
assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id assert helper.get_provider_model_id(known_model.model_id) == known_model.provider_model_id
@ -176,35 +169,31 @@ async def test_unregister_model(helper: ModelRegistryHelper, known_model: Model)
assert helper.get_provider_model_id(known_model.model_id) is None assert helper.get_provider_model_id(known_model.model_id) is None
@pytest.mark.asyncio
async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None: async def test_unregister_unknown_model(helper: ModelRegistryHelper, unknown_model: Model) -> None:
with pytest.raises(ValueError): with pytest.raises(ValueError):
await helper.unregister_model(unknown_model.model_id) await helper.unregister_model(unknown_model.model_id)
@pytest.mark.asyncio
async def test_register_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None: async def test_register_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
@pytest.mark.asyncio
async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None: async def test_unregister_model_during_init(helper: ModelRegistryHelper, known_model: Model) -> None:
assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id assert helper.get_provider_model_id(known_model.provider_resource_id) == known_model.provider_model_id
await helper.unregister_model(known_model.provider_resource_id) await helper.unregister_model(known_model.provider_resource_id)
assert helper.get_provider_model_id(known_model.provider_resource_id) is None assert helper.get_provider_model_id(known_model.provider_resource_id) is None
@pytest.mark.asyncio async def test_register_model_from_check_model_availability(
async def test_register_model_from_query_available_models(
helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model
) -> None: ) -> None:
"""Test that models returned by query_available_models can be registered.""" """Test that models returned by check_model_availability can be registered."""
# Verify the model is not in static config # Verify the model is not in static config
assert helper_with_dynamic_models.get_provider_model_id(dynamic_model.provider_resource_id) is None assert helper_with_dynamic_models.get_provider_model_id(dynamic_model.provider_resource_id) is None
# But it should be available via query_available_models # But it should be available via check_model_availability
available_models = await helper_with_dynamic_models.query_available_models() is_available = await helper_with_dynamic_models.check_model_availability(dynamic_model.provider_resource_id)
assert dynamic_model.provider_resource_id in available_models assert is_available
# Registration should succeed # Registration should succeed
registered_model = await helper_with_dynamic_models.register_model(dynamic_model) registered_model = await helper_with_dynamic_models.register_model(dynamic_model)
@ -216,7 +205,6 @@ async def test_register_model_from_query_available_models(
) )
@pytest.mark.asyncio
async def test_register_model_not_in_static_or_dynamic( async def test_register_model_not_in_static_or_dynamic(
helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, unknown_model: Model helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, unknown_model: Model
) -> None: ) -> None:
@ -224,20 +212,19 @@ async def test_register_model_not_in_static_or_dynamic(
# Verify the model is not in static config # Verify the model is not in static config
assert helper_with_dynamic_models.get_provider_model_id(unknown_model.provider_resource_id) is None assert helper_with_dynamic_models.get_provider_model_id(unknown_model.provider_resource_id) is None
# And not in dynamic models # And not available via check_model_availability
available_models = await helper_with_dynamic_models.query_available_models() is_available = await helper_with_dynamic_models.check_model_availability(unknown_model.provider_resource_id)
assert unknown_model.provider_resource_id not in available_models assert not is_available
# Registration should fail with comprehensive error message # Registration should fail with comprehensive error message
with pytest.raises(Exception) as exc_info: # UnsupportedModelError with pytest.raises(Exception) as exc_info: # UnsupportedModelError
await helper_with_dynamic_models.register_model(unknown_model) await helper_with_dynamic_models.register_model(unknown_model)
# Error should include both static and dynamic models # Error should include static models and "..." for dynamic models
error_str = str(exc_info.value) error_str = str(exc_info.value)
assert "dynamic-provider-id" in error_str # dynamic model should be in error assert "..." in error_str # "..." should be in error message
@pytest.mark.asyncio
async def test_register_alias_for_dynamic_model( async def test_register_alias_for_dynamic_model(
helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model helper_with_dynamic_models: MockModelRegistryHelperWithDynamicModels, dynamic_model: Model
) -> None: ) -> None:

View file

@ -11,7 +11,6 @@ import pytest
from llama_stack.providers.utils.scheduler import JobStatus, Scheduler from llama_stack.providers.utils.scheduler import JobStatus, Scheduler
@pytest.mark.asyncio
async def test_scheduler_unknown_backend(): async def test_scheduler_unknown_backend():
with pytest.raises(ValueError): with pytest.raises(ValueError):
Scheduler(backend="unknown") Scheduler(backend="unknown")
@ -26,7 +25,6 @@ async def wait_for_job_completed(sched: Scheduler, job_id: str) -> None:
raise TimeoutError(f"Job {job_id} did not complete in time.") raise TimeoutError(f"Job {job_id} did not complete in time.")
@pytest.mark.asyncio
async def test_scheduler_naive(): async def test_scheduler_naive():
sched = Scheduler() sched = Scheduler()
@ -87,7 +85,6 @@ async def test_scheduler_naive():
assert job.logs[0][0] < job.logs[1][0] assert job.logs[0][0] < job.logs[1][0]
@pytest.mark.asyncio
async def test_scheduler_naive_handler_raises(): async def test_scheduler_naive_handler_raises():
sched = Scheduler() sched = Scheduler()

View file

@ -8,10 +8,20 @@ import random
import numpy as np import numpy as np
import pytest import pytest
from pymilvus import MilvusClient, connections
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata from llama_stack.apis.vector_io import Chunk, ChunkMetadata
from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig
from llama_stack.providers.inline.vector_io.faiss.faiss import FaissIndex, FaissVectorIOAdapter
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig
from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import SQLiteVecIndex, SQLiteVecVectorIOAdapter
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex, MilvusVectorIOAdapter
EMBEDDING_DIMENSION = 384 EMBEDDING_DIMENSION = 384
COLLECTION_PREFIX = "test_collection"
MILVUS_ALIAS = "test_milvus"
@pytest.fixture @pytest.fixture
@ -50,7 +60,194 @@ def sample_chunks():
return sample return sample
@pytest.fixture(scope="session")
def sample_chunks_with_metadata():
"""Generates chunks that force multiple batches for a single document to expose ID conflicts."""
n, k = 10, 3
sample = [
Chunk(
content=f"Sentence {i} from document {j}",
metadata={"document_id": f"document-{j}"},
chunk_metadata=ChunkMetadata(
document_id=f"document-{j}",
chunk_id=f"document-{j}-chunk-{i}",
source=f"example source-{j}-{i}",
),
)
for j in range(k)
for i in range(n)
]
return sample
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def sample_embeddings(sample_chunks): def sample_embeddings(sample_chunks):
np.random.seed(42) np.random.seed(42)
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks]) return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks])
@pytest.fixture(scope="session")
def sample_embeddings_with_metadata(sample_chunks_with_metadata):
np.random.seed(42)
return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks_with_metadata])
@pytest.fixture(params=["milvus", "sqlite_vec", "faiss"])
def vector_provider(request):
return request.param
@pytest.fixture(scope="session")
def mock_inference_api(embedding_dimension):
class MockInferenceAPI:
async def embed_batch(self, texts: list[str]) -> list[list[float]]:
return [np.random.rand(embedding_dimension).astype(np.float32).tolist() for _ in texts]
return MockInferenceAPI()
@pytest.fixture
async def unique_kvstore_config(tmp_path_factory):
# Generate a unique filename for this test
unique_id = f"test_kv_{np.random.randint(1e6)}"
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / f"{unique_id}.db")
return SqliteKVStoreConfig(db_path=db_path)
@pytest.fixture(scope="session")
def sqlite_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_sqlite_vec.db")
return db_path
@pytest.fixture
async def sqlite_vec_vec_index(embedding_dimension, tmp_path_factory):
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / f"test_sqlite_vec_{np.random.randint(1e6)}.db")
bank_id = f"sqlite_vec_bank_{np.random.randint(1e6)}"
index = SQLiteVecIndex(embedding_dimension, db_path, bank_id)
await index.initialize()
index.db_path = db_path
yield index
index.delete()
@pytest.fixture
async def sqlite_vec_adapter(sqlite_vec_db_path, mock_inference_api, embedding_dimension):
config = SQLiteVectorIOConfig(
db_path=sqlite_vec_db_path,
kvstore=SqliteKVStoreConfig(),
)
adapter = SQLiteVecVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
collection_id = f"sqlite_test_collection_{np.random.randint(1e6)}"
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=collection_id,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
adapter.test_collection_id = collection_id
yield adapter
await adapter.shutdown()
@pytest.fixture(scope="session")
def milvus_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_milvus.db")
return db_path
@pytest.fixture
async def milvus_vec_index(milvus_vec_db_path, embedding_dimension):
client = MilvusClient(milvus_vec_db_path)
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_db_path)
index = MilvusIndex(client, name, consistency_level="Strong")
index.db_path = milvus_vec_db_path
yield index
@pytest.fixture
async def milvus_vec_adapter(milvus_vec_db_path, mock_inference_api):
config = MilvusVectorIOConfig(
db_path=milvus_vec_db_path,
kvstore=SqliteKVStoreConfig(),
)
adapter = MilvusVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=adapter.metadata_collection_name,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=128,
)
)
yield adapter
await adapter.shutdown()
@pytest.fixture
def faiss_vec_db_path(tmp_path_factory):
db_path = str(tmp_path_factory.getbasetemp() / "test_faiss.db")
return db_path
@pytest.fixture
async def faiss_vec_index(embedding_dimension):
index = FaissIndex(embedding_dimension)
yield index
await index.delete()
@pytest.fixture
async def faiss_vec_adapter(unique_kvstore_config, mock_inference_api, embedding_dimension):
config = FaissVectorIOConfig(
kvstore=unique_kvstore_config,
)
adapter = FaissVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=f"faiss_test_collection_{np.random.randint(1e6)}",
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=embedding_dimension,
)
)
yield adapter
await adapter.shutdown()
@pytest.fixture
def vector_io_adapter(vector_provider, request):
"""Returns the appropriate vector IO adapter based on the provider parameter."""
if vector_provider == "milvus":
return request.getfixturevalue("milvus_vec_adapter")
elif vector_provider == "faiss":
return request.getfixturevalue("faiss_vec_adapter")
else:
return request.getfixturevalue("sqlite_vec_adapter")
@pytest.fixture
def vector_index(vector_provider, request):
"""Returns appropriate vector index based on provider parameter"""
return request.getfixturevalue(f"{vector_provider}_vec_index")
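Because vector_provider is parametrized over milvus, sqlite_vec, and faiss, any test that requests vector_io_adapter (or vector_index) is collected once per backend. A minimal sketch of a test that would fan out across all three providers, using only names defined in this conftest (the test itself is illustrative, not part of the change):

async def test_register_vector_db_roundtrip(vector_io_adapter):
    # Collected three times: once each for the milvus, sqlite_vec, and faiss adapters.
    db = VectorDB(
        identifier="roundtrip_db",
        provider_id="test_provider",
        embedding_model="test_model",
        embedding_dimension=EMBEDDING_DIMENSION,
    )
    await vector_io_adapter.register_vector_db(db)
    assert "roundtrip_db" in vector_io_adapter.cache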

View file

@ -9,7 +9,6 @@ from unittest.mock import AsyncMock, MagicMock, patch
import numpy as np import numpy as np
import pytest import pytest
import pytest_asyncio
from llama_stack.apis.files import Files from llama_stack.apis.files import Files
from llama_stack.apis.inference import EmbeddingsResponse, Inference from llama_stack.apis.inference import EmbeddingsResponse, Inference
@ -91,13 +90,13 @@ def faiss_config():
return config return config
@pytest_asyncio.fixture @pytest.fixture
async def faiss_index(embedding_dimension): async def faiss_index(embedding_dimension):
index = await FaissIndex.create(dimension=embedding_dimension) index = await FaissIndex.create(dimension=embedding_dimension)
yield index yield index
@pytest_asyncio.fixture @pytest.fixture
async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> FaissVectorIOAdapter: async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> FaissVectorIOAdapter:
# Create the adapter # Create the adapter
adapter = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api, files_api=mock_files_api) adapter = FaissVectorIOAdapter(config=faiss_config, inference_api=mock_inference_api, files_api=mock_files_api)
@ -113,7 +112,6 @@ async def faiss_adapter(faiss_config, mock_inference_api, mock_files_api) -> Fai
yield adapter yield adapter
@pytest.mark.asyncio
async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_identical( async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_identical(
faiss_index, sample_chunks, sample_embeddings, embedding_dimension faiss_index, sample_chunks, sample_embeddings, embedding_dimension
): ):
@ -136,7 +134,6 @@ async def test_faiss_query_vector_returns_infinity_when_query_and_embedding_are_
assert response.chunks[1] == sample_chunks[1] assert response.chunks[1] == sample_chunks[1]
@pytest.mark.asyncio
async def test_health_success(): async def test_health_success():
"""Test that the health check returns OK status when faiss is working correctly.""" """Test that the health check returns OK status when faiss is working correctly."""
# Create a fresh instance of FaissVectorIOAdapter for testing # Create a fresh instance of FaissVectorIOAdapter for testing
@ -160,7 +157,6 @@ async def test_health_success():
mock_index_flat.assert_called_once_with(128) # VECTOR_DIMENSION is 128 mock_index_flat.assert_called_once_with(128) # VECTOR_DIMENSION is 128
@pytest.mark.asyncio
async def test_health_failure(): async def test_health_failure():
"""Test that the health check returns ERROR status when faiss encounters an error.""" """Test that the health check returns ERROR status when faiss encounters an error."""
# Create a fresh instance of FaissVectorIOAdapter for testing # Create a fresh instance of FaissVectorIOAdapter for testing

View file

@ -10,7 +10,6 @@ from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import pytest_asyncio
from llama_stack.apis.inference import EmbeddingsResponse, Inference from llama_stack.apis.inference import EmbeddingsResponse, Inference
from llama_stack.apis.vector_io import ( from llama_stack.apis.vector_io import (
@ -68,7 +67,7 @@ def mock_api_service(sample_embeddings):
return mock_api_service return mock_api_service
@pytest_asyncio.fixture @pytest.fixture
async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter: async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter:
adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service) adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service)
adapter.vector_db_store = mock_vector_db_store adapter.vector_db_store = mock_vector_db_store
@ -80,7 +79,6 @@ async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service,
__QUERY = "Sample query" __QUERY = "Sample query"
@pytest.mark.asyncio
@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 60)]) @pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 60)])
async def test_qdrant_adapter_returns_expected_chunks( async def test_qdrant_adapter_returns_expected_chunks(
qdrant_adapter: QdrantVectorIOAdapter, qdrant_adapter: QdrantVectorIOAdapter,
@ -111,7 +109,6 @@ def _prepare_for_json(value: Any) -> str:
@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json) @patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json)
@pytest.mark.asyncio
async def test_qdrant_register_and_unregister_vector_db( async def test_qdrant_register_and_unregister_vector_db(
qdrant_adapter: QdrantVectorIOAdapter, qdrant_adapter: QdrantVectorIOAdapter,
mock_vector_db, mock_vector_db,

View file

@ -8,7 +8,6 @@ import asyncio
import numpy as np import numpy as np
import pytest import pytest
import pytest_asyncio
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import ( from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import (
@ -34,7 +33,7 @@ def loop():
return asyncio.new_event_loop() return asyncio.new_event_loop()
@pytest_asyncio.fixture(scope="session", autouse=True) @pytest.fixture
async def sqlite_vec_index(embedding_dimension, tmp_path_factory): async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
temp_dir = tmp_path_factory.getbasetemp() temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / "test_sqlite.db") db_path = str(temp_dir / "test_sqlite.db")
@ -43,39 +42,14 @@ async def sqlite_vec_index(embedding_dimension, tmp_path_factory):
await index.delete() await index.delete()
@pytest.mark.asyncio async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks_with_metadata, sample_embeddings_with_metadata):
async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings): await sqlite_vec_index.add_chunks(sample_chunks_with_metadata, sample_embeddings_with_metadata)
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=2) response = await sqlite_vec_index.query_vector(sample_embeddings_with_metadata[-1], k=2, score_threshold=0.0)
connection = _create_sqlite_connection(sqlite_vec_index.db_path) assert response.chunks[0].chunk_metadata == sample_chunks_with_metadata[-1].chunk_metadata
cur = connection.cursor()
cur.execute(f"SELECT COUNT(*) FROM {sqlite_vec_index.metadata_table}")
count = cur.fetchone()[0]
assert count == len(sample_chunks)
cur.close()
connection.close()
@pytest.mark.asyncio
async def test_query_chunks_vector(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_embedding = np.random.rand(embedding_dimension).astype(np.float32)
response = await sqlite_vec_index.query_vector(query_embedding, k=2, score_threshold=0.0)
assert isinstance(response, QueryChunksResponse)
assert len(response.chunks) == 2
@pytest.mark.xfail(reason="Chunk Metadata not yet supported for SQLite-vec", strict=True)
async def test_query_chunk_metadata(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_embedding = sample_embeddings[0]
response = await sqlite_vec_index.query_vector(query_embedding, k=2, score_threshold=0.0)
assert response.chunks[-1].chunk_metadata == sample_chunks[-1].chunk_metadata
@pytest.mark.asyncio
async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_string = "Sentence 5" query_string = "Sentence 5"
response = await sqlite_vec_index.query_keyword(k=3, score_threshold=0.0, query_string=query_string) response = await sqlite_vec_index.query_keyword(k=3, score_threshold=0.0, query_string=query_string)
@ -91,7 +65,6 @@ async def test_query_chunks_full_text_search(sqlite_vec_index, sample_chunks, sa
assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}" assert len(response_no_results.chunks) == 0, f"Expected 0 results, but got {len(response_no_results.chunks)}"
@pytest.mark.asyncio
async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -113,7 +86,6 @@ async def test_query_chunks_hybrid(sqlite_vec_index, sample_chunks, sample_embed
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1)) assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_index, sample_chunks, sample_embeddings):
# Re-initialize with a clean index # Re-initialize with a clean index
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -126,7 +98,6 @@ async def test_query_chunks_full_text_search_k_greater_than_results(sqlite_vec_i
assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks), "Expected chunk not found" assert any("Sentence 1 from document 0" in chunk.content for chunk in response.chunks), "Expected chunk not found"
@pytest.mark.asyncio
async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension): async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension):
"""Test that chunk IDs do not conflict across batches when inserting chunks.""" """Test that chunk IDs do not conflict across batches when inserting chunks."""
# Reduce batch size to force multiple batches for same document # Reduce batch size to force multiple batches for same document
@ -148,7 +119,7 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime
assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!" assert len(chunk_ids) == len(set(chunk_ids)), "Duplicate chunk IDs detected across batches!"
@pytest_asyncio.fixture(scope="session") @pytest.fixture(scope="session")
async def sqlite_vec_adapter(sqlite_connection): async def sqlite_vec_adapter(sqlite_connection):
config = type("Config", (object,), {"db_path": ":memory:"}) # Mock config with in-memory database config = type("Config", (object,), {"db_path": ":memory:"}) # Mock config with in-memory database
adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None) adapter = SQLiteVecVectorIOAdapter(config=config, inference_api=None)
@ -157,7 +128,6 @@ async def sqlite_vec_adapter(sqlite_connection):
await adapter.shutdown() await adapter.shutdown()
@pytest.mark.asyncio
async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test hybrid search when keyword search returns no matches - should still return vector results.""" """Test hybrid search when keyword search returns no matches - should still return vector results."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -186,7 +156,6 @@ async def test_query_chunks_hybrid_no_keyword_matches(sqlite_vec_index, sample_c
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1)) assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test hybrid search with a high score threshold.""" """Test hybrid search with a high score threshold."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -208,7 +177,6 @@ async def test_query_chunks_hybrid_score_threshold(sqlite_vec_index, sample_chun
assert len(response.chunks) == 0 assert len(response.chunks) == 0
@pytest.mark.asyncio
async def test_query_chunks_hybrid_different_embedding( async def test_query_chunks_hybrid_different_embedding(
sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension
): ):
@ -234,7 +202,6 @@ async def test_query_chunks_hybrid_different_embedding(
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1)) assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test that RRF properly combines rankings when documents appear in both search methods.""" """Test that RRF properly combines rankings when documents appear in both search methods."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -259,7 +226,6 @@ async def test_query_chunks_hybrid_rrf_ranking(sqlite_vec_index, sample_chunks,
assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1)) assert all(response.scores[i] >= response.scores[i + 1] for i in range(len(response.scores) - 1))
@pytest.mark.asyncio
async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -307,7 +273,6 @@ async def test_query_chunks_hybrid_score_selection(sqlite_vec_index, sample_chun
assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6) # Should behave like RRF assert response.scores[0] == pytest.approx(2.0 / 61.0, rel=1e-6) # Should behave like RRF
@pytest.mark.asyncio
async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test hybrid search with documents that appear in only one search method.""" """Test hybrid search with documents that appear in only one search method."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -336,7 +301,6 @@ async def test_query_chunks_hybrid_mixed_results(sqlite_vec_index, sample_chunks
assert "document-2" in doc_ids # From keyword search assert "document-2" in doc_ids # From keyword search
@pytest.mark.asyncio
async def test_query_chunks_hybrid_weighted_reranker_parametrization( async def test_query_chunks_hybrid_weighted_reranker_parametrization(
sqlite_vec_index, sample_chunks, sample_embeddings sqlite_vec_index, sample_chunks, sample_embeddings
): ):
@ -392,7 +356,6 @@ async def test_query_chunks_hybrid_weighted_reranker_parametrization(
) )
@pytest.mark.asyncio
async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_chunks, sample_embeddings):
"""Test RRFReRanker with different impact factors.""" """Test RRFReRanker with different impact factors."""
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -424,7 +387,6 @@ async def test_query_chunks_hybrid_rrf_impact_factor(sqlite_vec_index, sample_ch
assert response.scores[0] == pytest.approx(2.0 / 101.0, rel=1e-6) assert response.scores[0] == pytest.approx(2.0 / 101.0, rel=1e-6)
@pytest.mark.asyncio
async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, sample_embeddings): async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, sample_embeddings):
await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings)
@ -468,7 +430,6 @@ async def test_query_chunks_hybrid_edge_cases(sqlite_vec_index, sample_chunks, s
assert len(response.chunks) <= 100 assert len(response.chunks) <= 100
@pytest.mark.asyncio
async def test_query_chunks_hybrid_tie_breaking( async def test_query_chunks_hybrid_tie_breaking(
sqlite_vec_index, sample_embeddings, embedding_dimension, tmp_path_factory sqlite_vec_index, sample_embeddings, embedding_dimension, tmp_path_factory
): ):

View file

@ -4,253 +4,130 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import json
import time import time
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
import numpy as np import numpy as np
import pytest import pytest
import pytest_asyncio
from pymilvus import Collection, MilvusClient, connections
from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse from llama_stack.apis.vector_io import Chunk, QueryChunksResponse
from llama_stack.providers.inline.vector_io.milvus.config import MilvusVectorIOConfig, SqliteKVStoreConfig from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX
from llama_stack.providers.remote.vector_io.milvus.milvus import VECTOR_DBS_PREFIX, MilvusIndex, MilvusVectorIOAdapter
from llama_stack.providers.utils.kvstore import kvstore_impl
# TODO: Refactor these to be for inline vector-io providers # This test is a unit test for the inline VectorIO providers. This should only contain
MILVUS_ALIAS = "test_milvus" # tests which are specific to this class. More general (API-level) tests should be placed in
COLLECTION_PREFIX = "test_collection" # tests/integration/vector_io/
#
# How to run this test:
#
# pytest tests/unit/providers/vector_io/test_vector_io_openai_vector_stores.py \
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
@pytest.fixture(scope="session") async def test_initialize_index(vector_index):
def loop(): await vector_index.initialize()
return asyncio.new_event_loop()
@pytest.fixture(scope="session") async def test_add_chunks_query_vector(vector_index, sample_chunks, sample_embeddings):
def mock_inference_api(embedding_dimension): vector_index.delete()
class MockInferenceAPI: vector_index.initialize()
async def embed_batch(self, texts: list[str]) -> list[list[float]]: await vector_index.add_chunks(sample_chunks, sample_embeddings)
return [np.random.rand(embedding_dimension).astype(np.float32).tolist() for _ in texts] resp = await vector_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
return MockInferenceAPI()
@pytest_asyncio.fixture
async def unique_kvstore_config(tmp_path_factory):
# Generate a unique filename for this test
unique_id = f"test_kv_{np.random.randint(1e6)}"
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / f"{unique_id}.db")
return SqliteKVStoreConfig(db_path=db_path)
@pytest_asyncio.fixture(scope="session", autouse=True)
async def milvus_vec_index(embedding_dimension, tmp_path_factory):
temp_dir = tmp_path_factory.getbasetemp()
db_path = str(temp_dir / "test_milvus.db")
client = MilvusClient(db_path)
name = f"{COLLECTION_PREFIX}_{np.random.randint(1e6)}"
connections.connect(alias=MILVUS_ALIAS, uri=db_path)
index = MilvusIndex(client, name, consistency_level="Strong")
index.db_path = db_path
yield index
@pytest_asyncio.fixture(scope="session")
async def milvus_vec_adapter(milvus_vec_index, mock_inference_api):
config = MilvusVectorIOConfig(
db_path=milvus_vec_index.db_path,
kvstore=SqliteKVStoreConfig(),
)
adapter = MilvusVectorIOAdapter(
config=config,
inference_api=mock_inference_api,
files_api=None,
)
await adapter.initialize()
await adapter.register_vector_db(
VectorDB(
identifier=adapter.metadata_collection_name,
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=128,
)
)
yield adapter
await adapter.shutdown()
@pytest.mark.asyncio
async def test_cache_contains_initial_collection(milvus_vec_adapter):
coll_name = milvus_vec_adapter.metadata_collection_name
assert coll_name in milvus_vec_adapter.cache
@pytest.mark.asyncio
async def test_add_chunks(milvus_vec_index, sample_chunks, sample_embeddings):
await milvus_vec_index.add_chunks(sample_chunks, sample_embeddings)
resp = await milvus_vec_index.query_vector(sample_embeddings[0], k=1, score_threshold=-1)
assert resp.chunks[0].content == sample_chunks[0].content assert resp.chunks[0].content == sample_chunks[0].content
vector_index.delete()
@pytest.mark.asyncio async def test_chunk_id_conflict(vector_index, sample_chunks, embedding_dimension):
async def test_query_chunks_vector(milvus_vec_index, sample_chunks, sample_embeddings, embedding_dimension):
await milvus_vec_index.add_chunks(sample_chunks, sample_embeddings)
query_emb = np.random.rand(embedding_dimension).astype(np.float32)
resp = await milvus_vec_index.query_vector(query_emb, k=2, score_threshold=0.0)
assert isinstance(resp, QueryChunksResponse)
assert len(resp.chunks) == 2
@pytest.mark.asyncio
async def test_chunk_id_conflict(milvus_vec_index, sample_chunks, embedding_dimension):
embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32) embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32)
await milvus_vec_index.add_chunks(sample_chunks, embeddings) await vector_index.add_chunks(sample_chunks, embeddings)
coll = Collection(milvus_vec_index.collection_name, using=MILVUS_ALIAS) resp = await vector_index.query_vector(
ids = coll.query(expr="id >= 0", output_fields=["id"], timeout=30) np.random.rand(embedding_dimension).astype(np.float32),
flat_ids = [i["id"] for i in ids] k=len(sample_chunks),
assert len(flat_ids) == len(set(flat_ids)) score_threshold=-1,
@pytest.mark.asyncio
async def test_initialize_with_milvus_client(milvus_vec_index, unique_kvstore_config):
kvstore = await kvstore_impl(unique_kvstore_config)
vector_db = VectorDB(
identifier="test_db",
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=128,
metadata={"test_key": "test_value"},
)
test_vector_db_data = vector_db.model_dump_json()
await kvstore.set(f"{VECTOR_DBS_PREFIX}test_db", test_vector_db_data)
tmp_milvus_vec_adapter = MilvusVectorIOAdapter(
config=MilvusVectorIOConfig(
db_path=milvus_vec_index.db_path,
kvstore=unique_kvstore_config,
),
inference_api=None,
files_api=None,
)
await tmp_milvus_vec_adapter.initialize()
vector_db = VectorDB(
identifier="test_db",
provider_id="test_provider",
embedding_model="test_model",
embedding_dimension=128,
)
test_vector_db_data = vector_db.model_dump_json()
await tmp_milvus_vec_adapter.kvstore.set(f"{VECTOR_DBS_PREFIX}/test_db", test_vector_db_data)
assert milvus_vec_index.client is not None
assert isinstance(milvus_vec_index.client, MilvusClient)
assert tmp_milvus_vec_adapter.cache is not None
# registering a vector won't update the cache or openai_vector_store collection name
assert (
tmp_milvus_vec_adapter.metadata_collection_name not in tmp_milvus_vec_adapter.cache
or tmp_milvus_vec_adapter.openai_vector_stores
) )
contents = [chunk.content for chunk in resp.chunks]
assert len(contents) == len(set(contents))
@pytest.mark.asyncio
async def test_persistence_across_adapter_restarts( async def test_initialize_adapter_with_existing_kvstore(vector_io_adapter):
tmp_path, milvus_vec_index, mock_inference_api, unique_kvstore_config key = f"{VECTOR_DBS_PREFIX}db1"
):
adapter1 = MilvusVectorIOAdapter(
config=MilvusVectorIOConfig(db_path=milvus_vec_index.db_path, kvstore=unique_kvstore_config),
inference_api=mock_inference_api,
files_api=None,
)
await adapter1.initialize()
dummy = VectorDB( dummy = VectorDB(
identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128 identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
) )
await adapter1.register_vector_db(dummy) await vector_io_adapter.kvstore.set(key=key, value=json.dumps(dummy.model_dump()))
await adapter1.shutdown()
await adapter1.initialize() await vector_io_adapter.initialize()
assert "foo_db" in adapter1.cache
await adapter1.shutdown()
@pytest.mark.asyncio async def test_persistence_across_adapter_restarts(vector_io_adapter):
async def test_register_and_unregister_vector_db(milvus_vec_adapter): await vector_io_adapter.initialize()
try: dummy = VectorDB(
connections.disconnect(MILVUS_ALIAS) identifier="foo_db", provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
except Exception as _: )
pass await vector_io_adapter.register_vector_db(dummy)
await vector_io_adapter.shutdown()
connections.connect(alias=MILVUS_ALIAS, uri=milvus_vec_adapter.config.db_path) await vector_io_adapter.initialize()
assert "foo_db" in vector_io_adapter.cache
await vector_io_adapter.shutdown()
async def test_register_and_unregister_vector_db(vector_io_adapter):
unique_id = f"foo_db_{np.random.randint(1e6)}" unique_id = f"foo_db_{np.random.randint(1e6)}"
dummy = VectorDB( dummy = VectorDB(
identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128 identifier=unique_id, provider_id="test_provider", embedding_model="test_model", embedding_dimension=128
) )
await milvus_vec_adapter.register_vector_db(dummy) await vector_io_adapter.register_vector_db(dummy)
assert dummy.identifier in milvus_vec_adapter.cache assert dummy.identifier in vector_io_adapter.cache
await vector_io_adapter.unregister_vector_db(dummy.identifier)
if dummy.identifier in milvus_vec_adapter.cache: assert dummy.identifier not in vector_io_adapter.cache
index = milvus_vec_adapter.cache[dummy.identifier].index
if hasattr(index, "client") and hasattr(index.client, "_using"):
index.client._using = MILVUS_ALIAS
await milvus_vec_adapter.unregister_vector_db(dummy.identifier)
assert dummy.identifier not in milvus_vec_adapter.cache
@pytest.mark.asyncio
async def test_query_unregistered_raises(vector_io_adapter):
    fake_emb = np.zeros(8, dtype=np.float32)
    with pytest.raises(ValueError):
        await vector_io_adapter.query_chunks("no_such_db", fake_emb)
@pytest.mark.asyncio
async def test_insert_chunks_calls_underlying_index(vector_io_adapter):
    fake_index = AsyncMock()
    vector_io_adapter.cache["db1"] = fake_index

    chunks = ["chunk1", "chunk2"]
    await vector_io_adapter.insert_chunks("db1", chunks)

    fake_index.insert_chunks.assert_awaited_once_with(chunks)
@pytest.mark.asyncio
async def test_insert_chunks_missing_db_raises(vector_io_adapter):
    vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)

    with pytest.raises(ValueError):
        await vector_io_adapter.insert_chunks("db_not_exist", [])
@pytest.mark.asyncio
async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter):
    expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
    fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
    vector_io_adapter.cache["db1"] = fake_index

    response = await vector_io_adapter.query_chunks("db1", "my_query", {"param": 1})

    fake_index.query_chunks.assert_awaited_once_with("my_query", {"param": 1})
    assert response is expected
@pytest.mark.asyncio
async def test_query_chunks_missing_db_raises(vector_io_adapter):
    vector_io_adapter._get_and_cache_vector_db_index = AsyncMock(return_value=None)

    with pytest.raises(ValueError):
        await vector_io_adapter.query_chunks("db_missing", "q", None)
@pytest.mark.asyncio
async def test_save_openai_vector_store(vector_io_adapter):
    store_id = "vs_1234"
    openai_vector_store = {
        "id": store_id,
@@ -260,14 +137,13 @@ async def test_save_openai_vector_store(milvus_vec_adapter):
        "embedding_model": "test_model",
    }

    await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)

    assert openai_vector_store["id"] in vector_io_adapter.openai_vector_stores
    assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
@pytest.mark.asyncio
async def test_update_openai_vector_store(vector_io_adapter):
    store_id = "vs_1234"
    openai_vector_store = {
        "id": store_id,
@@ -277,14 +153,13 @@ async def test_update_openai_vector_store(milvus_vec_adapter):
        "embedding_model": "test_model",
    }

    await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)
    openai_vector_store["description"] = "Updated description"
    await vector_io_adapter._update_openai_vector_store(store_id, openai_vector_store)
    assert vector_io_adapter.openai_vector_stores[openai_vector_store["id"]] == openai_vector_store
@pytest.mark.asyncio
async def test_delete_openai_vector_store(vector_io_adapter):
    store_id = "vs_1234"
    openai_vector_store = {
        "id": store_id,
@@ -294,13 +169,12 @@ async def test_delete_openai_vector_store(milvus_vec_adapter):
        "embedding_model": "test_model",
    }

    await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)
    await vector_io_adapter._delete_openai_vector_store_from_storage(store_id)
    assert openai_vector_store["id"] not in vector_io_adapter.openai_vector_stores
@pytest.mark.asyncio
async def test_load_openai_vector_stores(vector_io_adapter):
    store_id = "vs_1234"
    openai_vector_store = {
        "id": store_id,
@@ -310,13 +184,12 @@ async def test_load_openai_vector_stores(milvus_vec_adapter):
        "embedding_model": "test_model",
    }

    await vector_io_adapter._save_openai_vector_store(store_id, openai_vector_store)
    loaded_stores = await vector_io_adapter._load_openai_vector_stores()
    assert loaded_stores[store_id] == openai_vector_store
@pytest.mark.asyncio
async def test_save_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
    store_id = "vs_1234"
    file_id = "file_1234"
@@ -334,11 +207,10 @@ async def test_save_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory):
    ]

    # validating we don't raise an exception
    await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
@pytest.mark.asyncio
async def test_update_openai_vector_store_file(vector_io_adapter, tmp_path_factory):
    store_id = "vs_1234"
    file_id = "file_1234"
@@ -355,24 +227,23 @@ async def test_update_openai_vector_store_file(milvus_vec_adapter, tmp_path_factory):
        {"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
    ]

    await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)

    updated_file_info = file_info.copy()
    updated_file_info["filename"] = "updated_test_file.txt"

    await vector_io_adapter._update_openai_vector_store_file(
        store_id,
        file_id,
        updated_file_info,
    )

    loaded_contents = await vector_io_adapter._load_openai_vector_store_file(store_id, file_id)
    assert loaded_contents == updated_file_info
    assert loaded_contents != file_info
@pytest.mark.asyncio
async def test_load_openai_vector_store_file_contents(vector_io_adapter, tmp_path_factory):
    store_id = "vs_1234"
    file_id = "file_1234"
@@ -389,14 +260,13 @@ async def test_load_openai_vector_store_file_contents(milvus_vec_adapter, tmp_path_factory):
        {"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
    ]

    await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)

    loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id)
    assert loaded_contents == file_contents
@pytest.mark.asyncio
async def test_delete_openai_vector_store_file_from_storage(vector_io_adapter, tmp_path_factory):
    store_id = "vs_1234"
    file_id = "file_1234"
@@ -413,8 +283,10 @@ async def test_delete_openai_vector_store_file_from_storage(milvus_vec_adapter, tmp_path_factory):
        {"content": "Test content", "chunk_metadata": {"chunk_id": "chunk_001"}, "metadata": {"file_id": file_id}}
    ]

    await vector_io_adapter._save_openai_vector_store_file(store_id, file_id, file_info, file_contents)
    await vector_io_adapter._delete_openai_vector_store_file_from_storage(store_id, file_id)

    loaded_file_info = await vector_io_adapter._load_openai_vector_store_file(store_id, file_id)
    assert loaded_file_info == {}
    loaded_contents = await vector_io_adapter._load_openai_vector_store_file_contents(store_id, file_id)
    assert loaded_contents == []