chore: update documentation, add exception handling for Vector Stores in the RAG Tool, move inference onto the OpenAI-compatible APIs, and reuse existing libraries in the memory implementation

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-09-07 13:52:39 -04:00
parent 28696c3f30
commit ff0bd414b1
27 changed files with 926 additions and 403 deletions

View file

@ -5,21 +5,22 @@ inputs:
stack-config: stack-config:
description: 'Stack configuration to use' description: 'Stack configuration to use'
required: true required: true
provider: setup:
description: 'Provider to use for tests' description: 'Setup to use for tests (e.g., ollama, gpt, vllm)'
required: true required: false
default: ''
inference-mode: inference-mode:
description: 'Inference mode (record or replay)' description: 'Inference mode (record or replay)'
required: true required: true
test-suite: suite:
description: 'Test suite to use: base, responses, vision, etc.' description: 'Test suite to use: base, responses, vision, etc.'
required: false required: false
default: '' default: ''
test-subdirs: subdirs:
description: 'Comma-separated list of test subdirectories to run; overrides test-suite' description: 'Comma-separated list of test subdirectories to run; overrides suite'
required: false required: false
default: '' default: ''
test-pattern: pattern:
description: 'Regex pattern to pass to pytest -k' description: 'Regex pattern to pass to pytest -k'
required: false required: false
default: '' default: ''
@ -37,14 +38,23 @@ runs:
- name: Run Integration Tests - name: Run Integration Tests
shell: bash shell: bash
run: | run: |
uv run --no-sync ./scripts/integration-tests.sh \ SCRIPT_ARGS="--stack-config ${{ inputs.stack-config }} --inference-mode ${{ inputs.inference-mode }}"
--stack-config '${{ inputs.stack-config }}' \
--provider '${{ inputs.provider }}' \ # Add optional arguments only if they are provided
--test-subdirs '${{ inputs.test-subdirs }}' \ if [ -n '${{ inputs.setup }}' ]; then
--test-pattern '${{ inputs.test-pattern }}' \ SCRIPT_ARGS="$SCRIPT_ARGS --setup ${{ inputs.setup }}"
--inference-mode '${{ inputs.inference-mode }}' \ fi
--test-suite '${{ inputs.test-suite }}' \ if [ -n '${{ inputs.suite }}' ]; then
| tee pytest-${{ inputs.inference-mode }}.log SCRIPT_ARGS="$SCRIPT_ARGS --suite ${{ inputs.suite }}"
fi
if [ -n '${{ inputs.subdirs }}' ]; then
SCRIPT_ARGS="$SCRIPT_ARGS --subdirs ${{ inputs.subdirs }}"
fi
if [ -n '${{ inputs.pattern }}' ]; then
SCRIPT_ARGS="$SCRIPT_ARGS --pattern ${{ inputs.pattern }}"
fi
uv run --no-sync ./scripts/integration-tests.sh $SCRIPT_ARGS | tee pytest-${{ inputs.inference-mode }}.log
- name: Commit and push recordings - name: Commit and push recordings
@ -58,7 +68,7 @@ runs:
echo "New recordings detected, committing and pushing" echo "New recordings detected, committing and pushing"
git add tests/integration/recordings/ git add tests/integration/recordings/
git commit -m "Recordings update from CI (test-suite: ${{ inputs.test-suite }})" git commit -m "Recordings update from CI (suite: ${{ inputs.suite }})"
git fetch origin ${{ github.ref_name }} git fetch origin ${{ github.ref_name }}
git rebase origin/${{ github.ref_name }} git rebase origin/${{ github.ref_name }}
echo "Rebased successfully" echo "Rebased successfully"

View file

@ -1,7 +1,7 @@
name: Setup Ollama name: Setup Ollama
description: Start Ollama description: Start Ollama
inputs: inputs:
test-suite: suite:
description: 'Test suite to use: base, responses, vision, etc.' description: 'Test suite to use: base, responses, vision, etc.'
required: false required: false
default: '' default: ''
@ -11,7 +11,7 @@ runs:
- name: Start Ollama - name: Start Ollama
shell: bash shell: bash
run: | run: |
if [ "${{ inputs.test-suite }}" == "vision" ]; then if [ "${{ inputs.suite }}" == "vision" ]; then
image="ollama-with-vision-model" image="ollama-with-vision-model"
else else
image="ollama-with-models" image="ollama-with-models"

View file

@ -8,11 +8,11 @@ inputs:
client-version: client-version:
description: 'Client version (latest or published)' description: 'Client version (latest or published)'
required: true required: true
provider: setup:
description: 'Provider to setup (ollama or vllm)' description: 'Setup to configure (ollama, vllm, gpt, etc.)'
required: true required: false
default: 'ollama' default: 'ollama'
test-suite: suite:
description: 'Test suite to use: base, responses, vision, etc.' description: 'Test suite to use: base, responses, vision, etc.'
required: false required: false
default: '' default: ''
@ -30,13 +30,13 @@ runs:
client-version: ${{ inputs.client-version }} client-version: ${{ inputs.client-version }}
- name: Setup ollama - name: Setup ollama
if: ${{ inputs.provider == 'ollama' && inputs.inference-mode == 'record' }} if: ${{ (inputs.setup == 'ollama' || inputs.setup == 'ollama-vision') && inputs.inference-mode == 'record' }}
uses: ./.github/actions/setup-ollama uses: ./.github/actions/setup-ollama
with: with:
test-suite: ${{ inputs.test-suite }} suite: ${{ inputs.suite }}
- name: Setup vllm - name: Setup vllm
if: ${{ inputs.provider == 'vllm' && inputs.inference-mode == 'record' }} if: ${{ inputs.setup == 'vllm' && inputs.inference-mode == 'record' }}
uses: ./.github/actions/setup-vllm uses: ./.github/actions/setup-vllm
- name: Build Llama Stack - name: Build Llama Stack

View file

@ -28,8 +28,8 @@ on:
description: 'Test against both the latest and published versions' description: 'Test against both the latest and published versions'
type: boolean type: boolean
default: false default: false
test-provider: test-setup:
description: 'Test against a specific provider' description: 'Test against a specific setup'
type: string type: string
default: 'ollama' default: 'ollama'
@ -42,18 +42,18 @@ jobs:
run-replay-mode-tests: run-replay-mode-tests:
runs-on: ubuntu-latest runs-on: ubuntu-latest
name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, {4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.test-suite) }} name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, {4})', matrix.client-type, matrix.setup, matrix.python-version, matrix.client-version, matrix.suite) }}
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
client-type: [library, server] client-type: [library, server]
# Use vllm on weekly schedule, otherwise use test-provider input (defaults to ollama) # Use vllm on weekly schedule, otherwise use test-setup input (defaults to ollama)
provider: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-provider || 'ollama')) }} setup: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-setup || 'ollama')) }}
# Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12 # Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }} python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }} client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
test-suite: [base, vision] suite: [base, vision]
steps: steps:
- name: Checkout repository - name: Checkout repository
@ -64,14 +64,14 @@ jobs:
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
client-version: ${{ matrix.client-version }} client-version: ${{ matrix.client-version }}
provider: ${{ matrix.provider }} setup: ${{ matrix.setup }}
test-suite: ${{ matrix.test-suite }} suite: ${{ matrix.suite }}
inference-mode: 'replay' inference-mode: 'replay'
- name: Run tests - name: Run tests
uses: ./.github/actions/run-and-record-tests uses: ./.github/actions/run-and-record-tests
with: with:
stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }} stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }}
provider: ${{ matrix.provider }} setup: ${{ matrix.setup }}
inference-mode: 'replay' inference-mode: 'replay'
test-suite: ${{ matrix.test-suite }} suite: ${{ matrix.suite }}

View file

@ -48,7 +48,6 @@ jobs:
working-directory: llama_stack/ui working-directory: llama_stack/ui
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1 - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
continue-on-error: true
env: env:
SKIP: no-commit-to-branch SKIP: no-commit-to-branch
RUFF_OUTPUT_FORMAT: github RUFF_OUTPUT_FORMAT: github

View file

@ -10,19 +10,19 @@ run-name: Run the integration test suite from tests/integration
on: on:
workflow_dispatch: workflow_dispatch:
inputs: inputs:
test-provider: test-setup:
description: 'Test against a specific provider' description: 'Test against a specific setup'
type: string type: string
default: 'ollama' default: 'ollama'
test-suite: suite:
description: 'Test suite to use: base, responses, vision, etc.' description: 'Test suite to use: base, responses, vision, etc.'
type: string type: string
default: '' default: ''
test-subdirs: subdirs:
description: 'Comma-separated list of test subdirectories to run; overrides test-suite' description: 'Comma-separated list of test subdirectories to run; overrides suite'
type: string type: string
default: '' default: ''
test-pattern: pattern:
description: 'Regex pattern to pass to pytest -k' description: 'Regex pattern to pass to pytest -k'
type: string type: string
default: '' default: ''
@ -39,10 +39,10 @@ jobs:
run: | run: |
echo "::group::Workflow Inputs" echo "::group::Workflow Inputs"
echo "branch: ${{ github.ref_name }}" echo "branch: ${{ github.ref_name }}"
echo "test-provider: ${{ inputs.test-provider }}" echo "test-setup: ${{ inputs.test-setup }}"
echo "test-suite: ${{ inputs.test-suite }}" echo "suite: ${{ inputs.suite }}"
echo "test-subdirs: ${{ inputs.test-subdirs }}" echo "subdirs: ${{ inputs.subdirs }}"
echo "test-pattern: ${{ inputs.test-pattern }}" echo "pattern: ${{ inputs.pattern }}"
echo "::endgroup::" echo "::endgroup::"
- name: Checkout repository - name: Checkout repository
@ -55,16 +55,16 @@ jobs:
with: with:
python-version: "3.12" # Use single Python version for recording python-version: "3.12" # Use single Python version for recording
client-version: "latest" client-version: "latest"
provider: ${{ inputs.test-provider || 'ollama' }} setup: ${{ inputs.test-setup || 'ollama' }}
test-suite: ${{ inputs.test-suite }} suite: ${{ inputs.suite }}
inference-mode: 'record' inference-mode: 'record'
- name: Run and record tests - name: Run and record tests
uses: ./.github/actions/run-and-record-tests uses: ./.github/actions/run-and-record-tests
with: with:
stack-config: 'server:ci-tests' # recording must be done with server since more tests are run stack-config: 'server:ci-tests' # recording must be done with server since more tests are run
provider: ${{ inputs.test-provider || 'ollama' }} setup: ${{ inputs.test-setup || 'ollama' }}
inference-mode: 'record' inference-mode: 'record'
test-suite: ${{ inputs.test-suite }} suite: ${{ inputs.suite }}
test-subdirs: ${{ inputs.test-subdirs }} subdirs: ${{ inputs.subdirs }}
test-pattern: ${{ inputs.test-pattern }} pattern: ${{ inputs.pattern }}

View file

@ -93,10 +93,31 @@ chunks_response = client.vector_io.query(
### Using the RAG Tool ### Using the RAG Tool
> **⚠️ DEPRECATION NOTICE**: The RAG Tool is being deprecated in favor of directly using the OpenAI-compatible Search
> API. We recommend migrating to the OpenAI APIs for better compatibility and future support.
A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc. A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc.
and automatically chunks them into smaller pieces. More examples for how to format a RAGDocument can be found in the and automatically chunks them into smaller pieces. More examples for how to format a RAGDocument can be found in the
[appendix](#more-ragdocument-examples). [appendix](#more-ragdocument-examples).
#### OpenAI API Integration & Migration
The RAG tool has been updated to use OpenAI-compatible APIs. This provides several benefits:
- **Files API Integration**: Documents are now uploaded using OpenAI's file upload endpoints
- **Vector Stores API**: Vector storage operations use OpenAI's vector store format with configurable chunking strategies
- **Error Resilience**: When processing multiple documents, individual failures are logged but don't crash the operation. Failed documents are skipped while successful ones continue processing.
**Migration Path:**
We recommend migrating to the OpenAI-compatible Search API for:
1. **Better OpenAI Ecosystem Integration**: Direct compatibility with OpenAI tools and workflows including the Responses API
2. **Future-Proof**: Continued support and feature development
3. **Full OpenAI Compatibility**: Vector Stores, Files, and Search APIs are fully compatible with OpenAI's Responses API
The OpenAI APIs are used under the hood, so you can continue to use your existing RAG Tool code with minimal changes.
However, we recommend updating your code to use the new OpenAI-compatible APIs for better long-term support. If any
documents fail to process, they will be logged in the response but will not cause the entire operation to fail.
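
For a concrete sense of the migration target, here is a minimal sketch of the same flow using the OpenAI-compatible Files, Vector Stores, and Search APIs directly. It assumes a recent `openai` Python client and a running Llama Stack server whose OpenAI-compatible endpoints are reachable at the base URL shown; the URL, file name, and chunk sizes are placeholders, not prescribed values.

```python
from io import BytesIO

from openai import OpenAI

# Assumption: a running Llama Stack server exposing OpenAI-compatible endpoints here.
client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# 1. Upload the document via the Files API.
uploaded = client.files.create(
    file=("notes.txt", BytesIO(b"Llama Stack exposes OpenAI-compatible APIs.")),
    purpose="assistants",
)

# 2. Create a vector store and attach the file with a static chunking strategy.
store = client.vector_stores.create(name="my-docs")
client.vector_stores.files.create(
    vector_store_id=store.id,
    file_id=uploaded.id,
    chunking_strategy={
        "type": "static",
        "static": {"max_chunk_size_tokens": 512, "chunk_overlap_tokens": 128},
    },
)

# 3. Query the vector store via the Search API.
results = client.vector_stores.search(
    vector_store_id=store.id, query="What APIs does Llama Stack expose?"
)
for row in results.data:
    print(row.filename, row.score)
```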
```python ```python
from llama_stack_client import RAGDocument from llama_stack_client import RAGDocument

View file

@ -6,6 +6,7 @@ data:
apis: apis:
- agents - agents
- inference - inference
- files
- safety - safety
- telemetry - telemetry
- tool_runtime - tool_runtime
@ -19,13 +20,6 @@ data:
max_tokens: ${env.VLLM_MAX_TOKENS:=4096} max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake} api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true} tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: vllm-safety
provider_type: remote::vllm
config:
url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: sentence-transformers - provider_id: sentence-transformers
provider_type: inline::sentence-transformers provider_type: inline::sentence-transformers
config: {} config: {}
@ -41,6 +35,14 @@ data:
db: ${env.POSTGRES_DB:=llamastack} db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack} user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack} password: ${env.POSTGRES_PASSWORD:=llamastack}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
safety: safety:
- provider_id: llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard provider_type: inline::llama-guard
@ -111,9 +113,6 @@ data:
- model_id: ${env.INFERENCE_MODEL} - model_id: ${env.INFERENCE_MODEL}
provider_id: vllm-inference provider_id: vllm-inference
model_type: llm model_type: llm
- model_id: ${env.SAFETY_MODEL}
provider_id: vllm-safety
model_type: llm
shields: shields:
- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B} - shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
vector_dbs: [] vector_dbs: []

View file

@ -3,6 +3,7 @@ image_name: kubernetes-benchmark-demo
apis: apis:
- agents - agents
- inference - inference
- files
- safety - safety
- telemetry - telemetry
- tool_runtime - tool_runtime
@ -31,6 +32,14 @@ providers:
db: ${env.POSTGRES_DB:=llamastack} db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack} user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack} password: ${env.POSTGRES_PASSWORD:=llamastack}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
safety: safety:
- provider_id: llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard provider_type: inline::llama-guard

View file

@ -1,137 +1,55 @@
apiVersion: v1 apiVersion: v1
data: data:
stack_run_config.yaml: | stack_run_config.yaml: "version: '2'\nimage_name: kubernetes-demo\napis:\n- agents\n-
version: '2' inference\n- files\n- safety\n- telemetry\n- tool_runtime\n- vector_io\nproviders:\n
image_name: kubernetes-demo \ inference:\n - provider_id: vllm-inference\n provider_type: remote::vllm\n
apis: \ config:\n url: ${env.VLLM_URL:=http://localhost:8000/v1}\n max_tokens:
- agents ${env.VLLM_MAX_TOKENS:=4096}\n api_token: ${env.VLLM_API_TOKEN:=fake}\n tls_verify:
- inference ${env.VLLM_TLS_VERIFY:=true}\n - provider_id: vllm-safety\n provider_type:
- safety remote::vllm\n config:\n url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1}\n
- telemetry \ max_tokens: ${env.VLLM_MAX_TOKENS:=4096}\n api_token: ${env.VLLM_API_TOKEN:=fake}\n
- tool_runtime \ tls_verify: ${env.VLLM_TLS_VERIFY:=true}\n - provider_id: sentence-transformers\n
- vector_io \ provider_type: inline::sentence-transformers\n config: {}\n vector_io:\n
providers: \ - provider_id: ${env.ENABLE_CHROMADB:+chromadb}\n provider_type: remote::chromadb\n
inference: \ config:\n url: ${env.CHROMADB_URL:=}\n kvstore:\n type: postgres\n
- provider_id: vllm-inference \ host: ${env.POSTGRES_HOST:=localhost}\n port: ${env.POSTGRES_PORT:=5432}\n
provider_type: remote::vllm \ db: ${env.POSTGRES_DB:=llamastack}\n user: ${env.POSTGRES_USER:=llamastack}\n
config: \ password: ${env.POSTGRES_PASSWORD:=llamastack}\n files:\n - provider_id:
url: ${env.VLLM_URL:=http://localhost:8000/v1} meta-reference-files\n provider_type: inline::localfs\n config:\n storage_dir:
max_tokens: ${env.VLLM_MAX_TOKENS:=4096} ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}\n metadata_store:\n
api_token: ${env.VLLM_API_TOKEN:=fake} \ type: sqlite\n db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
tls_verify: ${env.VLLM_TLS_VERIFY:=true} \ \n safety:\n - provider_id: llama-guard\n provider_type: inline::llama-guard\n
- provider_id: vllm-safety \ config:\n excluded_categories: []\n agents:\n - provider_id: meta-reference\n
provider_type: remote::vllm \ provider_type: inline::meta-reference\n config:\n persistence_store:\n
config: \ type: postgres\n host: ${env.POSTGRES_HOST:=localhost}\n port:
url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1} ${env.POSTGRES_PORT:=5432}\n db: ${env.POSTGRES_DB:=llamastack}\n user:
max_tokens: ${env.VLLM_MAX_TOKENS:=4096} ${env.POSTGRES_USER:=llamastack}\n password: ${env.POSTGRES_PASSWORD:=llamastack}\n
api_token: ${env.VLLM_API_TOKEN:=fake} \ responses_store:\n type: postgres\n host: ${env.POSTGRES_HOST:=localhost}\n
tls_verify: ${env.VLLM_TLS_VERIFY:=true} \ port: ${env.POSTGRES_PORT:=5432}\n db: ${env.POSTGRES_DB:=llamastack}\n
- provider_id: sentence-transformers \ user: ${env.POSTGRES_USER:=llamastack}\n password: ${env.POSTGRES_PASSWORD:=llamastack}\n
provider_type: inline::sentence-transformers \ telemetry:\n - provider_id: meta-reference\n provider_type: inline::meta-reference\n
config: {} \ config:\n service_name: \"${env.OTEL_SERVICE_NAME:=\\u200B}\"\n sinks:
vector_io: ${env.TELEMETRY_SINKS:=console}\n tool_runtime:\n - provider_id: brave-search\n
- provider_id: ${env.ENABLE_CHROMADB:+chromadb} \ provider_type: remote::brave-search\n config:\n api_key: ${env.BRAVE_SEARCH_API_KEY:+}\n
provider_type: remote::chromadb \ max_results: 3\n - provider_id: tavily-search\n provider_type: remote::tavily-search\n
config: \ config:\n api_key: ${env.TAVILY_SEARCH_API_KEY:+}\n max_results:
url: ${env.CHROMADB_URL:=} 3\n - provider_id: rag-runtime\n provider_type: inline::rag-runtime\n config:
kvstore: {}\n - provider_id: model-context-protocol\n provider_type: remote::model-context-protocol\n
type: postgres \ config: {}\nmetadata_store:\n type: postgres\n host: ${env.POSTGRES_HOST:=localhost}\n
host: ${env.POSTGRES_HOST:=localhost} \ port: ${env.POSTGRES_PORT:=5432}\n db: ${env.POSTGRES_DB:=llamastack}\n user:
port: ${env.POSTGRES_PORT:=5432} ${env.POSTGRES_USER:=llamastack}\n password: ${env.POSTGRES_PASSWORD:=llamastack}\n
db: ${env.POSTGRES_DB:=llamastack} \ table_name: llamastack_kvstore\ninference_store:\n type: postgres\n host:
user: ${env.POSTGRES_USER:=llamastack} ${env.POSTGRES_HOST:=localhost}\n port: ${env.POSTGRES_PORT:=5432}\n db: ${env.POSTGRES_DB:=llamastack}\n
password: ${env.POSTGRES_PASSWORD:=llamastack} \ user: ${env.POSTGRES_USER:=llamastack}\n password: ${env.POSTGRES_PASSWORD:=llamastack}\nmodels:\n-
safety: metadata:\n embedding_dimension: 384\n model_id: all-MiniLM-L6-v2\n provider_id:
- provider_id: llama-guard sentence-transformers\n model_type: embedding\n- metadata: {}\n model_id: ${env.INFERENCE_MODEL}\n
provider_type: inline::llama-guard \ provider_id: vllm-inference\n model_type: llm\n- metadata: {}\n model_id:
config: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}\n provider_id: vllm-safety\n
excluded_categories: [] \ model_type: llm\nshields:\n- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}\nvector_dbs:
agents: []\ndatasets: []\nscoring_fns: []\nbenchmarks: []\ntool_groups:\n- toolgroup_id:
- provider_id: meta-reference builtin::websearch\n provider_id: tavily-search\n- toolgroup_id: builtin::rag\n
provider_type: inline::meta-reference \ provider_id: rag-runtime\nserver:\n port: 8321\n auth:\n provider_config:\n
config: \ type: github_token\n"
persistence_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
responses_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:+}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:+}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
table_name: llamastack_kvstore
inference_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
models:
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
model_type: embedding
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: vllm-inference
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
provider_id: vllm-safety
model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
auth:
provider_config:
type: github_token
kind: ConfigMap kind: ConfigMap
metadata: metadata:
creationTimestamp: null creationTimestamp: null

View file

@ -3,6 +3,7 @@ image_name: kubernetes-demo
apis: apis:
- agents - agents
- inference - inference
- files
- safety - safety
- telemetry - telemetry
- tool_runtime - tool_runtime
@ -38,6 +39,14 @@ providers:
db: ${env.POSTGRES_DB:=llamastack} db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack} user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack} password: ${env.POSTGRES_PASSWORD:=llamastack}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
safety: safety:
- provider_id: llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard provider_type: inline::llama-guard

View file

@ -45,6 +45,7 @@ from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.exec import formulate_run_args, run_command from llama_stack.core.utils.exec import formulate_run_args, run_command
from llama_stack.core.utils.image_types import LlamaStackImageType from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions" DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions"
@ -294,6 +295,12 @@ def _generate_run_config(
if build_config.external_providers_dir if build_config.external_providers_dir
else EXTERNAL_PROVIDERS_DIR, else EXTERNAL_PROVIDERS_DIR,
) )
if not run_config.inference_store:
run_config.inference_store = SqliteSqlStoreConfig(
**SqliteSqlStoreConfig.sample_run_config(
__distro_dir__=(DISTRIBS_BASE_DIR / image_name).as_posix(), db_name="inference_store.db"
)
)
# build providers dict # build providers dict
provider_registry = get_provider_registry(build_config) provider_registry = get_provider_registry(build_config)
for api in apis: for api in apis:

View file

@ -10,7 +10,6 @@ import json
import logging # allow-direct-logging import logging # allow-direct-logging
import os import os
import sys import sys
from concurrent.futures import ThreadPoolExecutor
from enum import Enum from enum import Enum
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
@ -148,7 +147,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
self.async_client = AsyncLlamaStackAsLibraryClient( self.async_client = AsyncLlamaStackAsLibraryClient(
config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
) )
self.pool_executor = ThreadPoolExecutor(max_workers=4)
self.provider_data = provider_data self.provider_data = provider_data
self.loop = asyncio.new_event_loop() self.loop = asyncio.new_event_loop()

View file

@ -8,7 +8,7 @@
from jinja2 import Template from jinja2 import Template
from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import UserMessage from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.apis.tools.rag_tool import ( from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig, DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig, LLMRAGQueryGeneratorConfig,
@ -61,16 +61,16 @@ async def llm_rag_query_generator(
messages = [interleaved_content_as_str(content)] messages = [interleaved_content_as_str(content)]
template = Template(config.template) template = Template(config.template)
content = template.render({"messages": messages}) rendered_content: str = template.render({"messages": messages})
model = config.model model = config.model
message = UserMessage(content=content) message = OpenAIUserMessageParam(content=rendered_content)
response = await inference_api.chat_completion( response = await inference_api.openai_chat_completion(
model_id=model, model=model,
messages=[message], messages=[message],
stream=False, stream=False,
) )
query = response.completion_message.content query = response.choices[0].message.content
return query return query

View file

@ -45,10 +45,7 @@ from llama_stack.apis.vector_io import (
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import parse_data_url
content_from_doc,
parse_data_url,
)
from .config import RagToolRuntimeConfig from .config import RagToolRuntimeConfig
from .context_retriever import generate_rag_query from .context_retriever import generate_rag_query
@ -60,6 +57,47 @@ def make_random_string(length: int = 8):
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length)) return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
"""Get raw binary data and mime type from a RAGDocument for file upload."""
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
parts = parse_data_url(doc.content.uri)
mime_type = parts["mimetype"]
data = parts["data"]
if parts["is_base64"]:
file_data = base64.b64decode(data)
else:
file_data = data.encode("utf-8")
return file_data, mime_type
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
r.raise_for_status()
mime_type = r.headers.get("content-type", "application/octet-stream")
return r.content, mime_type
else:
if isinstance(doc.content, str):
content_str = doc.content
else:
content_str = interleaved_content_as_str(doc.content)
if content_str.startswith("data:"):
parts = parse_data_url(content_str)
mime_type = parts["mimetype"]
data = parts["data"]
if parts["is_base64"]:
file_data = base64.b64decode(data)
else:
file_data = data.encode("utf-8")
return file_data, mime_type
else:
return content_str.encode("utf-8"), "text/plain"
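
As an illustration of what the new helper returns for the common document shapes, a hedged sketch follows; the import paths mirror this module's location and the sample values are made up.

```python
import asyncio

from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import RAGDocument
from llama_stack.providers.inline.tool_runtime.rag.memory import raw_data_from_doc


async def demo() -> None:
    # Plain text content comes back as UTF-8 bytes tagged text/plain.
    doc = RAGDocument(document_id="doc-inline", content="Llama Stack ships a RAG tool.")
    data, mime = await raw_data_from_doc(doc)
    print(mime, len(data))  # text/plain 29

    # A data: URL is decoded and tagged with its embedded mime type.
    doc = RAGDocument(
        document_id="doc-data-url",
        content=URL(uri="data:text/plain;base64,SGVsbG8gUkFHIQ=="),
    )
    data, mime = await raw_data_from_doc(doc)
    print(mime, data.decode())  # text/plain Hello RAG!


asyncio.run(demo())
```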
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime): class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
def __init__( def __init__(
self, self,
@ -95,46 +133,52 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
return return
for doc in documents: for doc in documents:
if isinstance(doc.content, URL): try:
if doc.content.uri.startswith("data:"): try:
parts = parse_data_url(doc.content.uri) file_data, mime_type = await raw_data_from_doc(doc)
file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode() except Exception as e:
mime_type = parts["mimetype"] log.error(f"Failed to extract content from document {doc.document_id}: {e}")
else: continue
async with httpx.AsyncClient() as client:
response = await client.get(doc.content.uri)
file_data = response.content
mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream")
else:
content_str = await content_from_doc(doc)
file_data = content_str.encode("utf-8")
mime_type = doc.mime_type or "text/plain"
file_extension = mimetypes.guess_extension(mime_type) or ".txt" file_extension = mimetypes.guess_extension(mime_type) or ".txt"
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}") filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
file_obj = io.BytesIO(file_data) file_obj = io.BytesIO(file_data)
file_obj.name = filename file_obj.name = filename
upload_file = UploadFile(file=file_obj, filename=filename) upload_file = UploadFile(file=file_obj, filename=filename)
created_file = await self.files_api.openai_upload_file( try:
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS created_file = await self.files_api.openai_upload_file(
) file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
)
except Exception as e:
log.error(f"Failed to upload file for document {doc.document_id}: {e}")
continue
chunking_strategy = VectorStoreChunkingStrategyStatic( chunking_strategy = VectorStoreChunkingStrategyStatic(
static=VectorStoreChunkingStrategyStaticConfig( static=VectorStoreChunkingStrategyStaticConfig(
max_chunk_size_tokens=chunk_size_in_tokens, max_chunk_size_tokens=chunk_size_in_tokens,
chunk_overlap_tokens=chunk_size_in_tokens // 4, chunk_overlap_tokens=chunk_size_in_tokens // 4,
)
) )
)
await self.vector_io_api.openai_attach_file_to_vector_store( try:
vector_store_id=vector_db_id, await self.vector_io_api.openai_attach_file_to_vector_store(
file_id=created_file.id, vector_store_id=vector_db_id,
attributes=doc.metadata, file_id=created_file.id,
chunking_strategy=chunking_strategy, attributes=doc.metadata,
) chunking_strategy=chunking_strategy,
)
except Exception as e:
log.error(
f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}"
)
continue
except Exception as e:
log.error(f"Unexpected error processing document {doc.document_id}: {e}")
continue
async def query( async def query(
self, self,
@ -167,8 +211,18 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
for vector_db_id in vector_db_ids for vector_db_id in vector_db_ids
] ]
results: list[QueryChunksResponse] = await asyncio.gather(*tasks) results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores] chunks = []
scores = []
for vector_db_id, result in zip(vector_db_ids, results, strict=False):
for chunk, score in zip(result.chunks, result.scores, strict=False):
if not hasattr(chunk, "metadata") or chunk.metadata is None:
chunk.metadata = {}
chunk.metadata["vector_db_id"] = vector_db_id
chunks.append(chunk)
scores.append(score)
if not chunks: if not chunks:
return RAGQueryResult(content=None) return RAGQueryResult(content=None)
@ -203,6 +257,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
metadata_keys_to_exclude_from_context = [ metadata_keys_to_exclude_from_context = [
"token_count", "token_count",
"metadata_token_count", "metadata_token_count",
"vector_db_id",
] ]
metadata_for_context = {} metadata_for_context = {}
for k in chunk_metadata_keys_to_include_from_context: for k in chunk_metadata_keys_to_include_from_context:
@ -227,6 +282,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]], "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
"chunks": [c.content for c in chunks[: len(picked)]], "chunks": [c.content for c in chunks[: len(picked)]],
"scores": scores[: len(picked)], "scores": scores[: len(picked)],
"vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]],
}, },
) )
@ -262,7 +318,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
if query_config: if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config) query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else: else:
# handle someone passing an empty dict
query_config = RAGQueryConfig() query_config = RAGQueryConfig()
query = kwargs["query"] query = kwargs["query"]
@ -273,6 +328,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
) )
return ToolInvocationResult( return ToolInvocationResult(
content=result.content, content=result.content or [],
metadata=result.metadata, metadata=result.metadata,
) )
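
For illustration only, the metadata returned by a knowledge_search invocation now also carries the originating vector store IDs alongside the scores; a hedged sketch of consuming that shape (all values are made up):

```python
# Shape of ToolInvocationResult.metadata after this change; sample values only.
metadata = {
    "document_ids": ["doc-1", "doc-2"],
    "chunks": ["...first chunk text...", "...second chunk text..."],
    "scores": [0.82, 0.61],
    "vector_db_ids": ["vs_products", "vs_faq"],
}

for doc_id, db_id, score in zip(
    metadata["document_ids"], metadata["vector_db_ids"], metadata["scores"], strict=True
):
    print(f"{doc_id} (from {db_id}): score={score:.3f}")
```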

View file

@ -218,7 +218,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter=AdapterSpec(
adapter_type="vertexai", adapter_type="vertexai",
pip_packages=["litellm", "google-cloud-aiplatform"], pip_packages=["litellm", "google-cloud-aiplatform", "openai"],
module="llama_stack.providers.remote.inference.vertexai", module="llama_stack.providers.remote.inference.vertexai",
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig", config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator", provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",

View file

@ -6,16 +6,20 @@
from typing import Any from typing import Any
import google.auth.transport.requests
from google.auth import default
from llama_stack.apis.inference import ChatCompletionRequest from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.utils.inference.litellm_openai_mixin import ( from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin, LiteLLMOpenAIMixin,
) )
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import VertexAIConfig from .config import VertexAIConfig
from .models import MODEL_ENTRIES from .models import MODEL_ENTRIES
class VertexAIInferenceAdapter(LiteLLMOpenAIMixin): class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: VertexAIConfig) -> None: def __init__(self, config: VertexAIConfig) -> None:
LiteLLMOpenAIMixin.__init__( LiteLLMOpenAIMixin.__init__(
self, self,
@ -27,9 +31,30 @@ class VertexAIInferenceAdapter(LiteLLMOpenAIMixin):
self.config = config self.config = config
def get_api_key(self) -> str: def get_api_key(self) -> str:
# Vertex AI doesn't use API keys, it uses Application Default Credentials """
# Return empty string to let litellm handle authentication via ADC Get an access token for Vertex AI using Application Default Credentials.
return ""
Vertex AI uses ADC instead of API keys. This method obtains an access token
from the default credentials and returns it for use with the OpenAI-compatible client.
"""
try:
# Get default credentials - will read from GOOGLE_APPLICATION_CREDENTIALS
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
credentials.refresh(google.auth.transport.requests.Request())
return str(credentials.token)
except Exception:
# If we can't get credentials, return empty string to let LiteLLM handle it
# This allows the LiteLLM mixin to work with ADC directly
return ""
def get_base_url(self) -> str:
"""
Get the Vertex AI OpenAI-compatible API base URL.
Returns the Vertex AI OpenAI-compatible endpoint URL.
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
"""
return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]: async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# Get base parameters from parent # Get base parameters from parent
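
To make the intended wiring concrete, here is a hedged sketch of using an ADC access token with the OpenAI client against the Vertex AI OpenAI-compatible endpoint described above; the project, location, and model ID are placeholders, and exact model naming depends on your Vertex deployment.

```python
import google.auth
import google.auth.transport.requests
from openai import OpenAI

# Obtain an access token from Application Default Credentials (ADC).
credentials, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
credentials.refresh(google.auth.transport.requests.Request())

project, location = "my-gcp-project", "us-central1"  # placeholders
client = OpenAI(
    api_key=credentials.token,
    base_url=(
        f"https://{location}-aiplatform.googleapis.com/v1/"
        f"projects/{project}/locations/{location}/endpoints/openapi"
    ),
)

response = client.chat.completions.create(
    model="google/gemini-2.0-flash-001",  # example model ID; varies by deployment
    messages=[{"role": "user", "content": "Hello from Vertex AI"}],
)
print(response.choices[0].message.content)
```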

71
scripts/get_setup_env.py Executable file
View file

@ -0,0 +1,71 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Small helper script to extract environment variables from a test setup.
Used by integration-tests.sh to set environment variables before starting the server.
"""
import argparse
import sys
from tests.integration.suites import SETUP_DEFINITIONS, SUITE_DEFINITIONS
def get_setup_env_vars(setup_name, suite_name=None):
"""
Get environment variables for a setup, with optional suite default fallback.
Args:
setup_name: Name of the setup (e.g., 'ollama', 'gpt')
suite_name: Optional suite name to get default setup if setup_name is None
Returns:
Dictionary of environment variables
"""
# If no setup specified, try to get default from suite
if not setup_name and suite_name:
suite = SUITE_DEFINITIONS.get(suite_name)
if suite and suite.default_setup:
setup_name = suite.default_setup
if not setup_name:
return {}
setup = SETUP_DEFINITIONS.get(setup_name)
if not setup:
print(
f"Error: Unknown setup '{setup_name}'. Available: {', '.join(sorted(SETUP_DEFINITIONS.keys()))}",
file=sys.stderr,
)
sys.exit(1)
return setup.env
def main():
parser = argparse.ArgumentParser(description="Extract environment variables from a test setup")
parser.add_argument("--setup", help="Setup name (e.g., ollama, gpt)")
parser.add_argument("--suite", help="Suite name to get default setup from if --setup not provided")
parser.add_argument("--format", choices=["bash", "json"], default="bash", help="Output format (default: bash)")
args = parser.parse_args()
env_vars = get_setup_env_vars(args.setup, args.suite)
if args.format == "bash":
# Output as bash export statements
for key, value in env_vars.items():
print(f"export {key}='{value}'")
elif args.format == "json":
import json
print(json.dumps(env_vars))
if __name__ == "__main__":
main()
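
For illustration, a hedged sketch of invoking this helper from Python instead of the shell wrapper; it assumes you run it from the repository root so the relative script path resolves, and that an `ollama` setup exists in SETUP_DEFINITIONS.

```python
import json
import os
import subprocess

# The helper imports tests.integration.suites, so make the repo root importable
# (integration-tests.sh does the equivalent by setting PYTHONPATH before calling it).
env = {**os.environ, "PYTHONPATH": "."}
proc = subprocess.run(
    ["python", "scripts/get_setup_env.py", "--setup", "ollama", "--format", "json"],
    capture_output=True,
    text=True,
    check=True,
    env=env,
)
for key, value in json.loads(proc.stdout).items():
    print(f"{key}={value}")
```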

View file

@ -14,7 +14,7 @@ set -euo pipefail
# Default values # Default values
BRANCH="" BRANCH=""
TEST_SUBDIRS="" TEST_SUBDIRS=""
TEST_PROVIDER="ollama" TEST_SETUP="ollama"
TEST_SUITE="base" TEST_SUITE="base"
TEST_PATTERN="" TEST_PATTERN=""
@ -27,24 +27,24 @@ Trigger the integration test recording workflow remotely. This way you do not ne
OPTIONS: OPTIONS:
-b, --branch BRANCH Branch to run the workflow on (defaults to current branch) -b, --branch BRANCH Branch to run the workflow on (defaults to current branch)
-p, --test-provider PROVIDER Test provider to use: vllm or ollama (default: ollama) -t, --suite SUITE Test suite to use: base, responses, vision, etc. (default: base)
-t, --test-suite SUITE Test suite to use: base, responses, vision, etc. (default: base) -p, --setup SETUP Test setup to use: vllm, ollama, gpt, etc. (default: ollama)
-s, --test-subdirs DIRS Comma-separated list of test subdirectories to run (overrides suite) -s, --subdirs DIRS Comma-separated list of test subdirectories to run (overrides suite)
-k, --test-pattern PATTERN Regex pattern to pass to pytest -k -k, --pattern PATTERN Regex pattern to pass to pytest -k
-h, --help Show this help message -h, --help Show this help message
EXAMPLES: EXAMPLES:
# Record tests for current branch with agents subdirectory # Record tests for current branch with agents subdirectory
$0 --test-subdirs "agents" $0 --subdirs "agents"
# Record tests for specific branch with vision tests # Record tests for specific branch with vision tests
$0 -b my-feature-branch --test-suite vision $0 -b my-feature-branch --suite vision
# Record multiple test subdirectories with specific provider # Record multiple test subdirectories with specific setup
$0 --test-subdirs "agents,inference" --test-provider vllm $0 --subdirs "agents,inference" --setup vllm
# Record tests matching a specific pattern # Record tests matching a specific pattern
$0 --test-subdirs "inference" --test-pattern "test_streaming" $0 --subdirs "inference" --pattern "test_streaming"
EOF EOF
} }
@ -63,19 +63,19 @@ while [[ $# -gt 0 ]]; do
BRANCH="$2" BRANCH="$2"
shift 2 shift 2
;; ;;
-s|--test-subdirs) -s|--subdirs)
TEST_SUBDIRS="$2" TEST_SUBDIRS="$2"
shift 2 shift 2
;; ;;
-p|--test-provider) -p|--setup)
TEST_PROVIDER="$2" TEST_SETUP="$2"
shift 2 shift 2
;; ;;
-t|--test-suite) -t|--suite)
TEST_SUITE="$2" TEST_SUITE="$2"
shift 2 shift 2
;; ;;
-k|--test-pattern) -k|--pattern)
TEST_PATTERN="$2" TEST_PATTERN="$2"
shift 2 shift 2
;; ;;
@ -93,21 +93,16 @@ done
# Validate required parameters # Validate required parameters
if [[ -z "$TEST_SUBDIRS" && -z "$TEST_SUITE" ]]; then if [[ -z "$TEST_SUBDIRS" && -z "$TEST_SUITE" ]]; then
echo "Error: --test-subdirs or --test-suite is required" echo "Error: --subdirs or --suite is required"
echo "Please specify which test subdirectories to run or test suite to use, e.g.:" echo "Please specify which test subdirectories to run or test suite to use, e.g.:"
echo " $0 --test-subdirs \"agents,inference\"" echo " $0 --subdirs \"agents,inference\""
echo " $0 --test-suite vision" echo " $0 --suite vision"
echo "" echo ""
exit 1 exit 1
fi fi
# Validate test provider # Validate test setup (optional - setups are validated by the workflow itself)
if [[ "$TEST_PROVIDER" != "vllm" && "$TEST_PROVIDER" != "ollama" ]]; then # Common setups: ollama, vllm, gpt, etc.
echo "❌ Error: Invalid test provider '$TEST_PROVIDER'"
echo " Supported providers: vllm, ollama"
echo " Example: $0 --test-subdirs \"agents\" --test-provider vllm"
exit 1
fi
# Check if required tools are installed # Check if required tools are installed
if ! command -v gh &> /dev/null; then if ! command -v gh &> /dev/null; then
@ -237,7 +232,7 @@ fi
# Build the workflow dispatch command # Build the workflow dispatch command
echo "Triggering integration test recording workflow..." echo "Triggering integration test recording workflow..."
echo "Branch: $BRANCH" echo "Branch: $BRANCH"
echo "Test provider: $TEST_PROVIDER" echo "Test setup: $TEST_SETUP"
echo "Test subdirs: $TEST_SUBDIRS" echo "Test subdirs: $TEST_SUBDIRS"
echo "Test suite: $TEST_SUITE" echo "Test suite: $TEST_SUITE"
echo "Test pattern: ${TEST_PATTERN:-"(none)"}" echo "Test pattern: ${TEST_PATTERN:-"(none)"}"
@ -245,16 +240,16 @@ echo ""
# Prepare inputs for gh workflow run # Prepare inputs for gh workflow run
if [[ -n "$TEST_SUBDIRS" ]]; then if [[ -n "$TEST_SUBDIRS" ]]; then
INPUTS="-f test-subdirs='$TEST_SUBDIRS'" INPUTS="-f subdirs='$TEST_SUBDIRS'"
fi fi
if [[ -n "$TEST_PROVIDER" ]]; then if [[ -n "$TEST_SETUP" ]]; then
INPUTS="$INPUTS -f test-provider='$TEST_PROVIDER'" INPUTS="$INPUTS -f test-setup='$TEST_SETUP'"
fi fi
if [[ -n "$TEST_SUITE" ]]; then if [[ -n "$TEST_SUITE" ]]; then
INPUTS="$INPUTS -f test-suite='$TEST_SUITE'" INPUTS="$INPUTS -f suite='$TEST_SUITE'"
fi fi
if [[ -n "$TEST_PATTERN" ]]; then if [[ -n "$TEST_PATTERN" ]]; then
INPUTS="$INPUTS -f test-pattern='$TEST_PATTERN'" INPUTS="$INPUTS -f pattern='$TEST_PATTERN'"
fi fi
# Run the workflow # Run the workflow

View file

@ -13,10 +13,10 @@ set -euo pipefail
# Default values # Default values
STACK_CONFIG="" STACK_CONFIG=""
PROVIDER="" TEST_SUITE="base"
TEST_SETUP=""
TEST_SUBDIRS="" TEST_SUBDIRS=""
TEST_PATTERN="" TEST_PATTERN=""
TEST_SUITE="base"
INFERENCE_MODE="replay" INFERENCE_MODE="replay"
EXTRA_PARAMS="" EXTRA_PARAMS=""
@ -27,29 +27,30 @@ Usage: $0 [OPTIONS]
Options: Options:
--stack-config STRING Stack configuration to use (required) --stack-config STRING Stack configuration to use (required)
--provider STRING Provider to use (ollama, vllm, etc.) (required) --suite STRING Test suite to run (default: 'base')
--test-suite STRING Comma-separated list of test suites to run (default: 'base') --setup STRING Test setup (models, env) to use (e.g., 'ollama', 'ollama-vision', 'gpt', 'vllm')
--inference-mode STRING Inference mode: record or replay (default: replay) --inference-mode STRING Inference mode: record or replay (default: replay)
--test-subdirs STRING Comma-separated list of test subdirectories to run (overrides suite) --subdirs STRING Comma-separated list of test subdirectories to run (overrides suite)
--test-pattern STRING Regex pattern to pass to pytest -k --pattern STRING Regex pattern to pass to pytest -k
--help Show this help message --help Show this help message
Suites are defined in tests/integration/suites.py. They are used to narrow the collection of tests and provide default model options. Suites are defined in tests/integration/suites.py and define which tests to run.
Setups are defined in tests/integration/setups.py and provide global configuration (models, env).
You can also specify subdirectories (of tests/integration) to select tests from, which will override the suite. You can also specify subdirectories (of tests/integration) to select tests from, which will override the suite.
Examples: Examples:
# Basic inference tests with ollama # Basic inference tests with ollama
$0 --stack-config server:ci-tests --provider ollama $0 --stack-config server:ci-tests --suite base --setup ollama
# Multiple test directories with vllm # Multiple test directories with vllm
$0 --stack-config server:ci-tests --provider vllm --test-subdirs 'inference,agents' $0 --stack-config server:ci-tests --subdirs 'inference,agents' --setup vllm
# Vision tests with ollama # Vision tests with ollama
$0 --stack-config server:ci-tests --provider ollama --test-suite vision $0 --stack-config server:ci-tests --suite vision # default setup for this suite is ollama-vision
# Record mode for updating test recordings # Record mode for updating test recordings
$0 --stack-config server:ci-tests --provider ollama --inference-mode record $0 --stack-config server:ci-tests --suite base --inference-mode record
EOF EOF
} }
@ -60,15 +61,15 @@ while [[ $# -gt 0 ]]; do
STACK_CONFIG="$2" STACK_CONFIG="$2"
shift 2 shift 2
;; ;;
--provider) --setup)
PROVIDER="$2" TEST_SETUP="$2"
shift 2 shift 2
;; ;;
--test-subdirs) --subdirs)
TEST_SUBDIRS="$2" TEST_SUBDIRS="$2"
shift 2 shift 2
;; ;;
--test-suite) --suite)
TEST_SUITE="$2" TEST_SUITE="$2"
shift 2 shift 2
;; ;;
@ -76,7 +77,7 @@ while [[ $# -gt 0 ]]; do
INFERENCE_MODE="$2" INFERENCE_MODE="$2"
shift 2 shift 2
;; ;;
--test-pattern) --pattern)
TEST_PATTERN="$2" TEST_PATTERN="$2"
shift 2 shift 2
;; ;;
@ -96,11 +97,13 @@ done
# Validate required parameters # Validate required parameters
if [[ -z "$STACK_CONFIG" ]]; then if [[ -z "$STACK_CONFIG" ]]; then
echo "Error: --stack-config is required" echo "Error: --stack-config is required"
usage
exit 1 exit 1
fi fi
if [[ -z "$PROVIDER" ]]; then if [[ -z "$TEST_SETUP" && -n "$TEST_SUBDIRS" ]]; then
echo "Error: --provider is required" echo "Error: --test-setup is required when --test-subdirs is provided"
usage
exit 1 exit 1
fi fi
@ -111,7 +114,7 @@ fi
echo "=== Llama Stack Integration Test Runner ===" echo "=== Llama Stack Integration Test Runner ==="
echo "Stack Config: $STACK_CONFIG" echo "Stack Config: $STACK_CONFIG"
echo "Provider: $PROVIDER" echo "Setup: $TEST_SETUP"
echo "Inference Mode: $INFERENCE_MODE" echo "Inference Mode: $INFERENCE_MODE"
echo "Test Suite: $TEST_SUITE" echo "Test Suite: $TEST_SUITE"
echo "Test Subdirs: $TEST_SUBDIRS" echo "Test Subdirs: $TEST_SUBDIRS"
@ -129,21 +132,25 @@ echo ""
# Set environment variables # Set environment variables
export LLAMA_STACK_CLIENT_TIMEOUT=300 export LLAMA_STACK_CLIENT_TIMEOUT=300
export LLAMA_STACK_TEST_INFERENCE_MODE="$INFERENCE_MODE"
# Configure provider-specific settings
if [[ "$PROVIDER" == "ollama" ]]; then
export OLLAMA_URL="http://0.0.0.0:11434"
export TEXT_MODEL="ollama/llama3.2:3b-instruct-fp16"
export SAFETY_MODEL="ollama/llama-guard3:1b"
EXTRA_PARAMS="--safety-shield=llama-guard"
else
export VLLM_URL="http://localhost:8000/v1"
export TEXT_MODEL="vllm/meta-llama/Llama-3.2-1B-Instruct"
EXTRA_PARAMS=""
fi
THIS_DIR=$(dirname "$0") THIS_DIR=$(dirname "$0")
if [[ -n "$TEST_SETUP" ]]; then
EXTRA_PARAMS="--setup=$TEST_SETUP"
fi
# Apply setup-specific environment variables (needed for server startup and tests)
echo "=== Applying Setup Environment Variables ==="
# the server needs this
export LLAMA_STACK_TEST_INFERENCE_MODE="$INFERENCE_MODE"
SETUP_ENV=$(PYTHONPATH=$THIS_DIR/.. python "$THIS_DIR/get_setup_env.py" --suite "$TEST_SUITE" --setup "$TEST_SETUP" --format bash)
echo "Setting up environment variables:"
echo "$SETUP_ENV"
eval "$SETUP_ENV"
echo ""
ROOT_DIR="$THIS_DIR/.." ROOT_DIR="$THIS_DIR/.."
cd $ROOT_DIR cd $ROOT_DIR
@ -162,6 +169,18 @@ fi
# Start Llama Stack Server if needed # Start Llama Stack Server if needed
if [[ "$STACK_CONFIG" == *"server:"* ]]; then if [[ "$STACK_CONFIG" == *"server:"* ]]; then
stop_server() {
echo "Stopping Llama Stack Server..."
pids=$(lsof -i :8321 | awk 'NR>1 {print $2}')
if [[ -n "$pids" ]]; then
echo "Killing Llama Stack Server processes: $pids"
kill -9 $pids
else
echo "No Llama Stack Server processes found ?!"
fi
echo "Llama Stack Server stopped"
}
# check if server is already running # check if server is already running
if curl -s http://localhost:8321/v1/health 2>/dev/null | grep -q "OK"; then if curl -s http://localhost:8321/v1/health 2>/dev/null | grep -q "OK"; then
echo "Llama Stack Server is already running, skipping start" echo "Llama Stack Server is already running, skipping start"
@ -185,14 +204,16 @@ if [[ "$STACK_CONFIG" == *"server:"* ]]; then
done done
echo "" echo ""
fi fi
trap stop_server EXIT ERR INT TERM
fi fi
# Run tests # Run tests
echo "=== Running Integration Tests ===" echo "=== Running Integration Tests ==="
EXCLUDE_TESTS="builtin_tool or safety_with_image or code_interpreter or test_rag" EXCLUDE_TESTS="builtin_tool or safety_with_image or code_interpreter or test_rag"
# Additional exclusions for vllm provider # Additional exclusions for vllm setup
if [[ "$PROVIDER" == "vllm" ]]; then if [[ "$TEST_SETUP" == "vllm" ]]; then
EXCLUDE_TESTS="${EXCLUDE_TESTS} or test_inference_store_tool_calls" EXCLUDE_TESTS="${EXCLUDE_TESTS} or test_inference_store_tool_calls"
fi fi
@ -229,20 +250,22 @@ if [[ -n "$TEST_SUBDIRS" ]]; then
echo "Total test files: $(echo $TEST_FILES | wc -w)" echo "Total test files: $(echo $TEST_FILES | wc -w)"
PYTEST_TARGET="$TEST_FILES" PYTEST_TARGET="$TEST_FILES"
EXTRA_PARAMS="$EXTRA_PARAMS --text-model=$TEXT_MODEL --embedding-model=sentence-transformers/all-MiniLM-L6-v2"
else else
PYTEST_TARGET="tests/integration/" PYTEST_TARGET="tests/integration/"
EXTRA_PARAMS="$EXTRA_PARAMS --suite=$TEST_SUITE" EXTRA_PARAMS="$EXTRA_PARAMS --suite=$TEST_SUITE"
fi fi
set +e set +e
set -x
pytest -s -v $PYTEST_TARGET \ pytest -s -v $PYTEST_TARGET \
--stack-config="$STACK_CONFIG" \ --stack-config="$STACK_CONFIG" \
--inference-mode="$INFERENCE_MODE" \
-k "$PYTEST_PATTERN" \ -k "$PYTEST_PATTERN" \
$EXTRA_PARAMS \ $EXTRA_PARAMS \
--color=yes \ --color=yes \
--capture=tee-sys --capture=tee-sys
exit_code=$? exit_code=$?
set +x
set -e set -e
if [ $exit_code -eq 0 ]; then if [ $exit_code -eq 0 ]; then
@ -260,18 +283,5 @@ echo "=== System Resources After Tests ==="
free -h 2>/dev/null || echo "free command not available" free -h 2>/dev/null || echo "free command not available"
df -h df -h
# stop server
if [[ "$STACK_CONFIG" == *"server:"* ]]; then
echo "Stopping Llama Stack Server..."
pids=$(lsof -i :8321 | awk 'NR>1 {print $2}')
if [[ -n "$pids" ]]; then
echo "Killing Llama Stack Server processes: $pids"
kill -9 $pids
else
echo "No Llama Stack Server processes found ?!"
fi
echo "Llama Stack Server stopped"
fi
echo "" echo ""
echo "=== Integration Tests Complete ===" echo "=== Integration Tests Complete ==="

View file

@ -6,9 +6,7 @@ Integration tests verify complete workflows across different providers using Lla
```bash ```bash
# Run all integration tests with existing recordings # Run all integration tests with existing recordings
LLAMA_STACK_TEST_INFERENCE_MODE=replay \ uv run --group test \
LLAMA_STACK_TEST_RECORDING_DIR=tests/integration/recordings \
uv run --group test \
pytest -sv tests/integration/ --stack-config=starter pytest -sv tests/integration/ --stack-config=starter
``` ```
@ -42,25 +40,35 @@ Model parameters can be influenced by the following options:
Each of these are comma-separated lists and can be used to generate multiple parameter combinations. Note that tests will be skipped Each of these are comma-separated lists and can be used to generate multiple parameter combinations. Note that tests will be skipped
if no model is specified. if no model is specified.
### Suites (fast selection + sane defaults) ### Suites and Setups
- `--suite`: comma-separated list of named suites that both narrow which tests are collected and prefill common model options (unless you pass them explicitly). - `--suite`: single named suite that narrows which tests are collected.
- Available suites: - Available suites:
- `responses`: collects tests under `tests/integration/responses`; this is a separate suite because it needs a strong tool-calling model. - `base`: collects most tests (excludes responses and post_training)
- `vision`: collects only `tests/integration/inference/test_vision_inference.py`; defaults `--vision-model=ollama/llama3.2-vision:11b`, `--embedding-model=sentence-transformers/all-MiniLM-L6-v2`. - `responses`: collects tests under `tests/integration/responses` (needs strong tool-calling models)
- Explicit flags always win. For example, `--suite=responses --text-model=<X>` overrides the suite's text model. - `vision`: collects only `tests/integration/inference/test_vision_inference.py`
- `--setup`: global configuration that can be used with any suite. Setups prefill model/env defaults; explicit CLI flags always win.
- Available setups:
- `ollama`: Local Ollama provider with lightweight models (sets OLLAMA_URL, uses llama3.2:3b-instruct-fp16)
- `vllm`: vLLM provider for efficient local inference (sets VLLM_URL, uses Llama-3.2-1B-Instruct)
- `gpt`: OpenAI GPT models for high-quality responses (uses gpt-4o)
- `claude`: Anthropic Claude models for high-quality responses (uses claude-3-5-sonnet)
Examples: Examples
```bash ```bash
# Fast responses run with defaults # Fast responses run with a strong tool-calling model
pytest -s -v tests/integration --stack-config=server:starter --suite=responses pytest -s -v tests/integration --stack-config=server:starter --suite=responses --setup=gpt
# Fast single-file vision run with defaults # Fast single-file vision run with Ollama defaults
pytest -s -v tests/integration --stack-config=server:starter --suite=vision pytest -s -v tests/integration --stack-config=server:starter --suite=vision --setup=ollama
# Combine suites and override a default # Base suite with vLLM for performance
pytest -s -v tests/integration --stack-config=server:starter --suite=responses,vision --embedding-model=text-embedding-3-small pytest -s -v tests/integration --stack-config=server:starter --suite=base --setup=vllm
# Override a default from setup
pytest -s -v tests/integration --stack-config=server:starter \
--suite=responses --setup=gpt --embedding-model=text-embedding-3-small
``` ```
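Another hedged variation, using the `claude` setup listed above (assumes Anthropic credentials are configured in the environment):

```bash
# Illustrative: responses suite with the Claude setup instead of GPT
pytest -s -v tests/integration --stack-config=server:starter --suite=responses --setup=claude
```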
## Examples ## Examples
@ -127,14 +135,13 @@ pytest tests/integration/
### RECORD Mode ### RECORD Mode
Captures API interactions for later replay: Captures API interactions for later replay:
```bash ```bash
LLAMA_STACK_TEST_INFERENCE_MODE=record \ pytest tests/integration/inference/test_new_feature.py --inference-mode=record
pytest tests/integration/inference/test_new_feature.py
``` ```
### LIVE Mode ### LIVE Mode
Tests make real API calls (but not recorded): Tests make real API calls (but not recorded):
```bash ```bash
LLAMA_STACK_TEST_INFERENCE_MODE=live pytest tests/integration/ pytest tests/integration/ --inference-mode=live
``` ```
By default, the recording directory is `tests/integration/recordings`. You can override this by setting the `LLAMA_STACK_TEST_RECORDING_DIR` environment variable. By default, the recording directory is `tests/integration/recordings`. You can override this by setting the `LLAMA_STACK_TEST_RECORDING_DIR` environment variable.
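For example, a sketch of recording into a non-default location (the directory path here is only an example):

```bash
# Illustrative: capture new recordings outside tests/integration/recordings
LLAMA_STACK_TEST_RECORDING_DIR=/tmp/llama-stack-recordings \
  pytest -s -v tests/integration/inference/test_new_feature.py \
  --stack-config=server:starter --inference-mode=record
```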
@ -155,15 +162,14 @@ cat recordings/responses/abc123.json | jq '.'
#### Remote Re-recording (Recommended) #### Remote Re-recording (Recommended)
Use the automated workflow script for easier re-recording: Use the automated workflow script for easier re-recording:
```bash ```bash
./scripts/github/schedule-record-workflow.sh --test-subdirs "inference,agents" ./scripts/github/schedule-record-workflow.sh --subdirs "inference,agents"
``` ```
See the [main testing guide](../README.md#remote-re-recording-recommended) for full details. See the [main testing guide](../README.md#remote-re-recording-recommended) for full details.
#### Local Re-recording #### Local Re-recording
```bash ```bash
# Re-record specific tests # Re-record specific tests
LLAMA_STACK_TEST_INFERENCE_MODE=record \ pytest -s -v --stack-config=server:starter tests/integration/inference/test_modified.py --inference-mode=record
pytest -s -v --stack-config=server:starter tests/integration/inference/test_modified.py
``` ```
Note that when re-recording tests, you must use a Stack pointing to a server (i.e., `server:starter`). This subtlety exists because the set of tests run in server are a superset of the set of tests run in the library client. Note that when re-recording tests, you must use a Stack pointing to a server (i.e., `server:starter`). This subtlety exists because the set of tests run in server are a superset of the set of tests run in the library client.
View file
@ -15,7 +15,7 @@ from dotenv import load_dotenv
from llama_stack.log import get_logger from llama_stack.log import get_logger
from .suites import SUITE_DEFINITIONS from .suites import SETUP_DEFINITIONS, SUITE_DEFINITIONS
logger = get_logger(__name__, category="tests") logger = get_logger(__name__, category="tests")
@ -63,19 +63,33 @@ def pytest_configure(config):
key, value = env_var.split("=", 1) key, value = env_var.split("=", 1)
os.environ[key] = value os.environ[key] = value
suites_raw = config.getoption("--suite") inference_mode = config.getoption("--inference-mode")
suites: list[str] = [] os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = inference_mode
if suites_raw:
suites = [p.strip() for p in str(suites_raw).split(",") if p.strip()] suite = config.getoption("--suite")
unknown = [p for p in suites if p not in SUITE_DEFINITIONS] if suite:
if unknown: if suite not in SUITE_DEFINITIONS:
raise pytest.UsageError(f"Unknown suite: {suite}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}")
# Apply setups (global parameterizations): env + defaults
setup = config.getoption("--setup")
if suite and not setup:
setup = SUITE_DEFINITIONS[suite].default_setup
if setup:
if setup not in SETUP_DEFINITIONS:
raise pytest.UsageError( raise pytest.UsageError(
f"Unknown suite(s): {', '.join(unknown)}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}" f"Unknown setup '{setup}'. Available: {', '.join(sorted(SETUP_DEFINITIONS.keys()))}"
) )
for suite in suites:
suite_def = SUITE_DEFINITIONS.get(suite, {}) setup_obj = SETUP_DEFINITIONS[setup]
defaults: dict = suite_def.get("defaults", {}) logger.info(f"Applying setup '{setup}'{' for suite ' + suite if suite else ''}")
for dest, value in defaults.items(): # Apply env first
for k, v in setup_obj.env.items():
if k not in os.environ:
os.environ[k] = str(v)
# Apply defaults if not provided explicitly
for dest, value in setup_obj.defaults.items():
current = getattr(config.option, dest, None) current = getattr(config.option, dest, None)
if not current: if not current:
setattr(config.option, dest, value) setattr(config.option, dest, value)
@ -120,6 +134,13 @@ def pytest_addoption(parser):
default=384, default=384,
help="Output dimensionality of the embedding model to use for testing. Default: 384", help="Output dimensionality of the embedding model to use for testing. Default: 384",
) )
parser.addoption(
"--inference-mode",
help="Inference mode: { record, replay, live } (default: replay)",
choices=["record", "replay", "live"],
default="replay",
)
parser.addoption( parser.addoption(
"--report", "--report",
help="Path where the test report should be written, e.g. --report=/path/to/report.md", help="Path where the test report should be written, e.g. --report=/path/to/report.md",
@ -127,14 +148,18 @@ def pytest_addoption(parser):
available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys())) available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys()))
suite_help = ( suite_help = (
"Comma-separated integration test suites to narrow collection and prefill defaults. " f"Single test suite to run (narrows collection). Available: {available_suites}. Example: --suite=responses"
"Available: "
f"{available_suites}. "
"Explicit CLI flags (e.g., --text-model) override suite defaults. "
"Examples: --suite=responses or --suite=responses,vision."
) )
parser.addoption("--suite", help=suite_help) parser.addoption("--suite", help=suite_help)
# Global setups for any suite
available_setups = ", ".join(sorted(SETUP_DEFINITIONS.keys()))
setup_help = (
f"Global test setup configuration. Available: {available_setups}. "
"Can be used with any suite. Example: --setup=ollama"
)
parser.addoption("--setup", help=setup_help)
MODEL_SHORT_IDS = { MODEL_SHORT_IDS = {
"meta-llama/Llama-3.2-3B-Instruct": "3B", "meta-llama/Llama-3.2-3B-Instruct": "3B",
@ -221,16 +246,12 @@ pytest_plugins = ["tests.integration.fixtures.common"]
def pytest_ignore_collect(path: str, config: pytest.Config) -> bool: def pytest_ignore_collect(path: str, config: pytest.Config) -> bool:
"""Skip collecting paths outside the selected suite roots for speed.""" """Skip collecting paths outside the selected suite roots for speed."""
suites_raw = config.getoption("--suite") suite = config.getoption("--suite")
if not suites_raw: if not suite:
return False return False
names = [p.strip() for p in str(suites_raw).split(",") if p.strip()] sobj = SUITE_DEFINITIONS.get(suite)
roots: list[str] = [] roots: list[str] = sobj.get("roots", []) if isinstance(sobj, dict) else getattr(sobj, "roots", [])
for name in names:
suite_def = SUITE_DEFINITIONS.get(name)
if suite_def:
roots.extend(suite_def.get("roots", []))
if not roots: if not roots:
return False return False
View file
@ -76,6 +76,9 @@ def skip_if_doesnt_support_n(client_with_models, model_id):
"remote::gemini", "remote::gemini",
# https://docs.anthropic.com/en/api/openai-sdk#simple-fields # https://docs.anthropic.com/en/api/openai-sdk#simple-fields
"remote::anthropic", "remote::anthropic",
"remote::vertexai",
# Error code: 400 - [{'error': {'code': 400, 'message': 'Unable to submit request because candidateCount must be 1 but
# the entered value was 2. Update the candidateCount value and try again.', 'status': 'INVALID_ARGUMENT'}
): ):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.") pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")
View file
@ -8,46 +8,112 @@
# For example: # For example:
# #
# ```bash # ```bash
# pytest tests/integration/ --suite=vision # pytest tests/integration/ --suite=vision --setup=ollama
# ``` # ```
# #
# Each suite can: """
# - restrict collection to specific roots (dirs or files) Each suite defines what to run (roots). Suites can be run with different global setups defined in this file (SETUP_DEFINITIONS).
# - provide default CLI option values (e.g. text_model, embedding_model, etc.) Setups provide environment variables and model defaults that can be reused across multiple suites.
CLI examples:
pytest tests/integration --suite=responses --setup=gpt
pytest tests/integration --suite=vision --setup=ollama
pytest tests/integration --suite=base --setup=vllm
"""
from pathlib import Path from pathlib import Path
from pydantic import BaseModel, Field
this_dir = Path(__file__).parent this_dir = Path(__file__).parent
default_roots = [
class Suite(BaseModel):
name: str
roots: list[str]
default_setup: str | None = None
class Setup(BaseModel):
"""A reusable test configuration with environment and CLI defaults."""
name: str
description: str
defaults: dict[str, str] = Field(default_factory=dict)
env: dict[str, str] = Field(default_factory=dict)
# Global setups: can be used with any suite in principle, but in practice some
# setups only make sense for specific test suites.
SETUP_DEFINITIONS: dict[str, Setup] = {
"ollama": Setup(
name="ollama",
description="Local Ollama provider with text + safety models",
env={
"OLLAMA_URL": "http://0.0.0.0:11434",
"SAFETY_MODEL": "ollama/llama-guard3:1b",
},
defaults={
"text_model": "ollama/llama3.2:3b-instruct-fp16",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
"safety_model": "ollama/llama-guard3:1b",
"safety_shield": "llama-guard",
},
),
"ollama-vision": Setup(
name="ollama",
description="Local Ollama provider with a vision model",
env={
"OLLAMA_URL": "http://0.0.0.0:11434",
},
defaults={
"vision_model": "ollama/llama3.2-vision:11b",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
},
),
"vllm": Setup(
name="vllm",
description="vLLM provider with a text model",
env={
"VLLM_URL": "http://localhost:8000/v1",
},
defaults={
"text_model": "vllm/meta-llama/Llama-3.2-1B-Instruct",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
},
),
"gpt": Setup(
name="gpt",
description="OpenAI GPT models for high-quality responses and tool calling",
defaults={
"text_model": "openai/gpt-4o",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
},
),
}
base_roots = [
str(p) str(p)
for p in this_dir.glob("*") for p in this_dir.glob("*")
if p.is_dir() if p.is_dir()
and p.name not in ("__pycache__", "fixtures", "test_cases", "recordings", "responses", "post_training") and p.name not in ("__pycache__", "fixtures", "test_cases", "recordings", "responses", "post_training")
] ]
SUITE_DEFINITIONS: dict[str, dict] = { SUITE_DEFINITIONS: dict[str, Suite] = {
"base": { "base": Suite(
"description": "Base suite that includes most tests but runs them with a text Ollama model", name="base",
"roots": default_roots, roots=base_roots,
"defaults": { default_setup="ollama",
"text_model": "ollama/llama3.2:3b-instruct-fp16", ),
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2", "responses": Suite(
}, name="responses",
}, roots=["tests/integration/responses"],
"responses": { default_setup="gpt",
"description": "Suite that includes only the OpenAI Responses tests; needs a strong tool-calling model", ),
"roots": ["tests/integration/responses"], "vision": Suite(
"defaults": { name="vision",
"text_model": "openai/gpt-4o", roots=["tests/integration/inference/test_vision_inference.py"],
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2", default_setup="ollama-vision",
}, ),
},
"vision": {
"description": "Suite that includes only the vision tests",
"roots": ["tests/integration/inference/test_vision_inference.py"],
"defaults": {
"vision_model": "ollama/llama3.2-vision:11b",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
},
},
} }
View file
@ -183,6 +183,110 @@ def test_vector_db_insert_from_url_and_query(
assert any("llama2" in chunk.content.lower() for chunk in response2.chunks) assert any("llama2" in chunk.content.lower() for chunk in response2.chunks)
def test_rag_tool_openai_apis(client_with_empty_registry, embedding_model_id, embedding_dimension):
vector_db_id = "test_openai_vector_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
actual_vector_db_id = available_vector_dbs[0]
# different document formats that should work with OpenAI APIs
documents = [
Document(
document_id="text-doc",
content="This is a plain text document about machine learning algorithms.",
metadata={"type": "text", "category": "AI"},
),
Document(
document_id="url-doc",
content="https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst",
mime_type="text/plain",
metadata={"type": "url", "source": "pytorch"},
),
Document(
document_id="data-url-doc",
content="data:text/plain;base64,VGhpcyBpcyBhIGRhdGEgVVJMIGRvY3VtZW50IGFib3V0IGRlZXAgbGVhcm5pbmcu", # "This is a data URL document about deep learning."
metadata={"type": "data_url", "encoding": "base64"},
),
]
client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=actual_vector_db_id,
chunk_size_in_tokens=256,
)
files_list = client_with_empty_registry.files.list()
assert len(files_list.data) >= len(documents), (
f"Expected at least {len(documents)} files, got {len(files_list.data)}"
)
vector_store_files = client_with_empty_registry.vector_io.openai_list_files_in_vector_store(
vector_store_id=actual_vector_db_id
)
assert len(vector_store_files.data) >= len(documents), f"Expected at least {len(documents)} files in vector store"
response = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[actual_vector_db_id],
content="Tell me about machine learning and deep learning",
)
assert_valid_text_response(response)
content_text = " ".join([chunk.text for chunk in response.content]).lower()
assert "machine learning" in content_text or "deep learning" in content_text
def test_rag_tool_exception_handling(client_with_empty_registry, embedding_model_id, embedding_dimension):
vector_db_id = "test_exception_handling"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
actual_vector_db_id = available_vector_dbs[0]
documents = [
Document(
document_id="valid-doc",
content="This is a valid document that should be processed successfully.",
metadata={"status": "valid"},
),
Document(
document_id="invalid-url-doc",
content="https://nonexistent-domain-12345.com/invalid.txt",
metadata={"status": "invalid_url"},
),
Document(
document_id="another-valid-doc",
content="This is another valid document for testing resilience.",
metadata={"status": "valid"},
),
]
client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=actual_vector_db_id,
chunk_size_in_tokens=256,
)
response = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[actual_vector_db_id],
content="valid document",
)
assert_valid_text_response(response)
content_text = " ".join([chunk.text for chunk in response.content]).lower()
assert "valid document" in content_text
def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id, embedding_dimension): def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id, embedding_dimension):
providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"] providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
assert len(providers) > 0 assert len(providers) > 0
@ -249,3 +353,107 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
"chunk_template": "This should raise a ValueError because it is missing the proper template variables", "chunk_template": "This should raise a ValueError because it is missing the proper template variables",
}, },
) )
def test_rag_tool_query_generation(client_with_empty_registry, embedding_model_id, embedding_dimension):
vector_db_id = "test_query_generation_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
actual_vector_db_id = available_vector_dbs[0]
documents = [
Document(
document_id="ai-doc",
content="Artificial intelligence and machine learning are transforming technology.",
metadata={"category": "AI"},
),
Document(
document_id="banana-doc",
content="Don't bring a banana to a knife fight.",
metadata={"category": "wisdom"},
),
]
client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=actual_vector_db_id,
chunk_size_in_tokens=256,
)
response = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[actual_vector_db_id],
content="Tell me about AI",
)
assert_valid_text_response(response)
content_text = " ".join([chunk.text for chunk in response.content]).lower()
assert "artificial intelligence" in content_text or "machine learning" in content_text
def test_rag_tool_pdf_data_url_handling(client_with_empty_registry, embedding_model_id, embedding_dimension):
vector_db_id = "test_pdf_data_url_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
actual_vector_db_id = available_vector_dbs[0]
sample_pdf = b"%PDF-1.3\n3 0 obj\n<</Type /Page\n/Parent 1 0 R\n/Resources 2 0 R\n/Contents 4 0 R>>\nendobj\n4 0 obj\n<</Filter /FlateDecode /Length 115>>\nstream\nx\x9c\x15\xcc1\x0e\x820\x18@\xe1\x9dS\xbcM]jk$\xd5\xd5(\x83!\x86\xa1\x17\xf8\xa3\xa5`LIh+\xd7W\xc6\xf7\r\xef\xc0\xbd\xd2\xaa\xb6,\xd5\xc5\xb1o\x0c\xa6VZ\xe3znn%\xf3o\xab\xb1\xe7\xa3:Y\xdc\x8bm\xeb\xf3&1\xc8\xd7\xd3\x97\xc82\xe6\x81\x87\xe42\xcb\x87Vb(\x12<\xdd<=}Jc\x0cL\x91\xee\xda$\xb5\xc3\xbd\xd7\xe9\x0f\x8d\x97 $\nendstream\nendobj\n1 0 obj\n<</Type /Pages\n/Kids [3 0 R ]\n/Count 1\n/MediaBox [0 0 595.28 841.89]\n>>\nendobj\n5 0 obj\n<</Type /Font\n/BaseFont /Helvetica\n/Subtype /Type1\n/Encoding /WinAnsiEncoding\n>>\nendobj\n2 0 obj\n<<\n/ProcSet [/PDF /Text /ImageB /ImageC /ImageI]\n/Font <<\n/F1 5 0 R\n>>\n/XObject <<\n>>\n>>\nendobj\n6 0 obj\n<<\n/Producer (PyFPDF 1.7.2 http://pyfpdf.googlecode.com/)\n/Title (This is a sample title.)\n/Author (Llama Stack Developers)\n/CreationDate (D:20250312165548)\n>>\nendobj\n7 0 obj\n<<\n/Type /Catalog\n/Pages 1 0 R\n/OpenAction [3 0 R /FitH null]\n/PageLayout /OneColumn\n>>\nendobj\nxref\n0 8\n0000000000 65535 f \n0000000272 00000 n \n0000000455 00000 n \n0000000009 00000 n \n0000000087 00000 n \n0000000359 00000 n \n0000000559 00000 n \n0000000734 00000 n \ntrailer\n<<\n/Size 8\n/Root 7 0 R\n/Info 6 0 R\n>>\nstartxref\n837\n%%EOF\n"
import base64
pdf_base64 = base64.b64encode(sample_pdf).decode("utf-8")
pdf_data_url = f"data:application/pdf;base64,{pdf_base64}"
documents = [
Document(
document_id="test-pdf-data-url",
content=pdf_data_url,
metadata={"type": "pdf", "source": "data_url"},
),
]
client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=actual_vector_db_id,
chunk_size_in_tokens=256,
)
files_list = client_with_empty_registry.files.list()
assert len(files_list.data) >= 1, "PDF should have been uploaded to Files API"
pdf_file = None
for file in files_list.data:
if file.filename and "test-pdf-data-url" in file.filename:
pdf_file = file
break
assert pdf_file is not None, "PDF file should be found in Files API"
assert pdf_file.bytes == len(sample_pdf), f"File size should match original PDF ({len(sample_pdf)} bytes)"
file_content = client_with_empty_registry.files.retrieve_content(pdf_file.id)
assert file_content.startswith(b"%PDF-"), "Retrieved file should be a valid PDF"
vector_store_files = client_with_empty_registry.vector_io.openai_list_files_in_vector_store(
vector_store_id=actual_vector_db_id
)
assert len(vector_store_files.data) >= 1, "PDF should be attached to vector store"
response = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[actual_vector_db_id],
content="sample title",
)
assert_valid_text_response(response)
content_text = " ".join([chunk.text for chunk in response.content]).lower()
assert "sample title" in content_text or "title" in content_text
View file
@ -178,3 +178,41 @@ def test_content_from_data_and_mime_type_both_encodings_fail():
# Should raise an exception instead of returning empty string # Should raise an exception instead of returning empty string
with pytest.raises(UnicodeDecodeError): with pytest.raises(UnicodeDecodeError):
content_from_data_and_mime_type(data, mime_type) content_from_data_and_mime_type(data, mime_type)
async def test_memory_tool_error_handling():
"""Test that memory tool handles various failures gracefully without crashing."""
from llama_stack.providers.inline.tool_runtime.rag.config import RagToolRuntimeConfig
from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl
config = RagToolRuntimeConfig()
memory_tool = MemoryToolRuntimeImpl(
config=config,
vector_io_api=AsyncMock(),
inference_api=AsyncMock(),
files_api=AsyncMock(),
)
docs = [
RAGDocument(document_id="good_doc", content="Good content", metadata={}),
RAGDocument(document_id="bad_url_doc", content=URL(uri="https://bad.url"), metadata={}),
RAGDocument(document_id="another_good_doc", content="Another good content", metadata={}),
]
mock_file1 = MagicMock()
mock_file1.id = "file_good1"
mock_file2 = MagicMock()
mock_file2.id = "file_good2"
memory_tool.files_api.openai_upload_file.side_effect = [mock_file1, mock_file2]
with patch("httpx.AsyncClient") as mock_client:
mock_instance = AsyncMock()
mock_instance.get.side_effect = Exception("Bad URL")
mock_client.return_value.__aenter__.return_value = mock_instance
# won't raise exception despite one document failing
await memory_tool.insert(docs, "vector_store_123")
# processed 2 documents successfully, skipped 1
assert memory_tool.files_api.openai_upload_file.call_count == 2
assert memory_tool.vector_io_api.openai_attach_file_to_vector_store.call_count == 2
View file
@ -81,3 +81,58 @@ class TestRagQuery:
# Test that invalid mode raises an error # Test that invalid mode raises an error
with pytest.raises(ValueError): with pytest.raises(ValueError):
RAGQueryConfig(mode="wrong_mode") RAGQueryConfig(mode="wrong_mode")
async def test_query_adds_vector_db_id_to_chunk_metadata(self):
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(),
vector_io_api=MagicMock(),
inference_api=MagicMock(),
files_api=MagicMock(),
)
vector_db_ids = ["db1", "db2"]
# Fake chunks from each DB
chunk_metadata1 = ChunkMetadata(
document_id="doc1",
chunk_id="chunk1",
source="test_source1",
metadata_token_count=5,
)
chunk1 = Chunk(
content="chunk from db1",
metadata={"vector_db_id": "db1", "document_id": "doc1"},
stored_chunk_id="c1",
chunk_metadata=chunk_metadata1,
)
chunk_metadata2 = ChunkMetadata(
document_id="doc2",
chunk_id="chunk2",
source="test_source2",
metadata_token_count=5,
)
chunk2 = Chunk(
content="chunk from db2",
metadata={"vector_db_id": "db2", "document_id": "doc2"},
stored_chunk_id="c2",
chunk_metadata=chunk_metadata2,
)
rag_tool.vector_io_api.query_chunks = AsyncMock(
side_effect=[
QueryChunksResponse(chunks=[chunk1], scores=[0.9]),
QueryChunksResponse(chunks=[chunk2], scores=[0.8]),
]
)
result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids)
returned_chunks = result.metadata["chunks"]
returned_scores = result.metadata["scores"]
returned_doc_ids = result.metadata["document_ids"]
returned_vector_db_ids = result.metadata["vector_db_ids"]
assert returned_chunks == ["chunk from db1", "chunk from db2"]
assert returned_scores == (0.9, 0.8)
assert returned_doc_ids == ["doc1", "doc2"]
assert returned_vector_db_ids == ["db1", "db2"]