chore: Updating documentation and adding exception handling for Vector Stores in RAG Tool and updating inference to use openai and updating memory implementation to use existing libraries

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-09-07 13:52:39 -04:00
parent 28696c3f30
commit ff0bd414b1
27 changed files with 926 additions and 403 deletions

View file

@ -5,21 +5,22 @@ inputs:
stack-config:
description: 'Stack configuration to use'
required: true
provider:
description: 'Provider to use for tests'
required: true
setup:
description: 'Setup to use for tests (e.g., ollama, gpt, vllm)'
required: false
default: ''
inference-mode:
description: 'Inference mode (record or replay)'
required: true
test-suite:
suite:
description: 'Test suite to use: base, responses, vision, etc.'
required: false
default: ''
test-subdirs:
description: 'Comma-separated list of test subdirectories to run; overrides test-suite'
subdirs:
description: 'Comma-separated list of test subdirectories to run; overrides suite'
required: false
default: ''
test-pattern:
pattern:
description: 'Regex pattern to pass to pytest -k'
required: false
default: ''
@ -37,14 +38,23 @@ runs:
- name: Run Integration Tests
shell: bash
run: |
uv run --no-sync ./scripts/integration-tests.sh \
--stack-config '${{ inputs.stack-config }}' \
--provider '${{ inputs.provider }}' \
--test-subdirs '${{ inputs.test-subdirs }}' \
--test-pattern '${{ inputs.test-pattern }}' \
--inference-mode '${{ inputs.inference-mode }}' \
--test-suite '${{ inputs.test-suite }}' \
| tee pytest-${{ inputs.inference-mode }}.log
SCRIPT_ARGS="--stack-config ${{ inputs.stack-config }} --inference-mode ${{ inputs.inference-mode }}"
# Add optional arguments only if they are provided
if [ -n '${{ inputs.setup }}' ]; then
SCRIPT_ARGS="$SCRIPT_ARGS --setup ${{ inputs.setup }}"
fi
if [ -n '${{ inputs.suite }}' ]; then
SCRIPT_ARGS="$SCRIPT_ARGS --suite ${{ inputs.suite }}"
fi
if [ -n '${{ inputs.subdirs }}' ]; then
SCRIPT_ARGS="$SCRIPT_ARGS --subdirs ${{ inputs.subdirs }}"
fi
if [ -n '${{ inputs.pattern }}' ]; then
SCRIPT_ARGS="$SCRIPT_ARGS --pattern ${{ inputs.pattern }}"
fi
uv run --no-sync ./scripts/integration-tests.sh $SCRIPT_ARGS | tee pytest-${{ inputs.inference-mode }}.log
- name: Commit and push recordings
@ -58,7 +68,7 @@ runs:
echo "New recordings detected, committing and pushing"
git add tests/integration/recordings/
git commit -m "Recordings update from CI (test-suite: ${{ inputs.test-suite }})"
git commit -m "Recordings update from CI (suite: ${{ inputs.suite }})"
git fetch origin ${{ github.ref_name }}
git rebase origin/${{ github.ref_name }}
echo "Rebased successfully"

View file

@ -1,7 +1,7 @@
name: Setup Ollama
description: Start Ollama
inputs:
test-suite:
suite:
description: 'Test suite to use: base, responses, vision, etc.'
required: false
default: ''
@ -11,7 +11,7 @@ runs:
- name: Start Ollama
shell: bash
run: |
if [ "${{ inputs.test-suite }}" == "vision" ]; then
if [ "${{ inputs.suite }}" == "vision" ]; then
image="ollama-with-vision-model"
else
image="ollama-with-models"

View file

@ -8,11 +8,11 @@ inputs:
client-version:
description: 'Client version (latest or published)'
required: true
provider:
description: 'Provider to setup (ollama or vllm)'
required: true
setup:
description: 'Setup to configure (ollama, vllm, gpt, etc.)'
required: false
default: 'ollama'
test-suite:
suite:
description: 'Test suite to use: base, responses, vision, etc.'
required: false
default: ''
@ -30,13 +30,13 @@ runs:
client-version: ${{ inputs.client-version }}
- name: Setup ollama
if: ${{ inputs.provider == 'ollama' && inputs.inference-mode == 'record' }}
if: ${{ (inputs.setup == 'ollama' || inputs.setup == 'ollama-vision') && inputs.inference-mode == 'record' }}
uses: ./.github/actions/setup-ollama
with:
test-suite: ${{ inputs.test-suite }}
suite: ${{ inputs.suite }}
- name: Setup vllm
if: ${{ inputs.provider == 'vllm' && inputs.inference-mode == 'record' }}
if: ${{ inputs.setup == 'vllm' && inputs.inference-mode == 'record' }}
uses: ./.github/actions/setup-vllm
- name: Build Llama Stack

View file

@ -28,8 +28,8 @@ on:
description: 'Test against both the latest and published versions'
type: boolean
default: false
test-provider:
description: 'Test against a specific provider'
test-setup:
description: 'Test against a specific setup'
type: string
default: 'ollama'
@ -42,18 +42,18 @@ jobs:
run-replay-mode-tests:
runs-on: ubuntu-latest
name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, {4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.test-suite) }}
name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, {4})', matrix.client-type, matrix.setup, matrix.python-version, matrix.client-version, matrix.suite) }}
strategy:
fail-fast: false
matrix:
client-type: [library, server]
# Use vllm on weekly schedule, otherwise use test-provider input (defaults to ollama)
provider: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-provider || 'ollama')) }}
# Use vllm on weekly schedule, otherwise use test-setup input (defaults to ollama)
setup: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-setup || 'ollama')) }}
# Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
test-suite: [base, vision]
suite: [base, vision]
steps:
- name: Checkout repository
@ -64,14 +64,14 @@ jobs:
with:
python-version: ${{ matrix.python-version }}
client-version: ${{ matrix.client-version }}
provider: ${{ matrix.provider }}
test-suite: ${{ matrix.test-suite }}
setup: ${{ matrix.setup }}
suite: ${{ matrix.suite }}
inference-mode: 'replay'
- name: Run tests
uses: ./.github/actions/run-and-record-tests
with:
stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }}
provider: ${{ matrix.provider }}
setup: ${{ matrix.setup }}
inference-mode: 'replay'
test-suite: ${{ matrix.test-suite }}
suite: ${{ matrix.suite }}

View file

@ -48,7 +48,6 @@ jobs:
working-directory: llama_stack/ui
- uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
continue-on-error: true
env:
SKIP: no-commit-to-branch
RUFF_OUTPUT_FORMAT: github

View file

@ -10,19 +10,19 @@ run-name: Run the integration test suite from tests/integration
on:
workflow_dispatch:
inputs:
test-provider:
description: 'Test against a specific provider'
test-setup:
description: 'Test against a specific setup'
type: string
default: 'ollama'
test-suite:
suite:
description: 'Test suite to use: base, responses, vision, etc.'
type: string
default: ''
test-subdirs:
description: 'Comma-separated list of test subdirectories to run; overrides test-suite'
subdirs:
description: 'Comma-separated list of test subdirectories to run; overrides suite'
type: string
default: ''
test-pattern:
pattern:
description: 'Regex pattern to pass to pytest -k'
type: string
default: ''
@ -39,10 +39,10 @@ jobs:
run: |
echo "::group::Workflow Inputs"
echo "branch: ${{ github.ref_name }}"
echo "test-provider: ${{ inputs.test-provider }}"
echo "test-suite: ${{ inputs.test-suite }}"
echo "test-subdirs: ${{ inputs.test-subdirs }}"
echo "test-pattern: ${{ inputs.test-pattern }}"
echo "test-setup: ${{ inputs.test-setup }}"
echo "suite: ${{ inputs.suite }}"
echo "subdirs: ${{ inputs.subdirs }}"
echo "pattern: ${{ inputs.pattern }}"
echo "::endgroup::"
- name: Checkout repository
@ -55,16 +55,16 @@ jobs:
with:
python-version: "3.12" # Use single Python version for recording
client-version: "latest"
provider: ${{ inputs.test-provider || 'ollama' }}
test-suite: ${{ inputs.test-suite }}
setup: ${{ inputs.test-setup || 'ollama' }}
suite: ${{ inputs.suite }}
inference-mode: 'record'
- name: Run and record tests
uses: ./.github/actions/run-and-record-tests
with:
stack-config: 'server:ci-tests' # recording must be done with server since more tests are run
provider: ${{ inputs.test-provider || 'ollama' }}
setup: ${{ inputs.test-setup || 'ollama' }}
inference-mode: 'record'
test-suite: ${{ inputs.test-suite }}
test-subdirs: ${{ inputs.test-subdirs }}
test-pattern: ${{ inputs.test-pattern }}
suite: ${{ inputs.suite }}
subdirs: ${{ inputs.subdirs }}
pattern: ${{ inputs.pattern }}

View file

@ -93,10 +93,31 @@ chunks_response = client.vector_io.query(
### Using the RAG Tool
> **⚠️ DEPRECATION NOTICE**: The RAG Tool is being deprecated in favor of directly using the OpenAI-compatible Search
> API. We recommend migrating to the OpenAI APIs for better compatibility and future support.
A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc.
and automatically chunks them into smaller pieces. More examples for how to format a RAGDocument can be found in the
[appendix](#more-ragdocument-examples).
#### OpenAI API Integration & Migration
The RAG tool has been updated to use OpenAI-compatible APIs. This provides several benefits:
- **Files API Integration**: Documents are now uploaded using OpenAI's file upload endpoints
- **Vector Stores API**: Vector storage operations use OpenAI's vector store format with configurable chunking strategies
- **Error Resilience:** When processing multiple documents, individual failures are logged but don't crash the operation. Failed documents are skipped while successful ones continue processing.
**Migration Path:**
We recommend migrating to the OpenAI-compatible Search API for:
1. **Better OpenAI Ecosystem Integration**: Direct compatibility with OpenAI tools and workflows including the Responses API
2**Future-Proof**: Continued support and feature development
3**Full OpenAI Compatibility**: Vector Stores, Files, and Search APIs are fully compatible with OpenAI's Responses API
The OpenAI APIs are used under the hood, so you can continue to use your existing RAG Tool code with minimal changes.
However, we recommend updating your code to use the new OpenAI-compatible APIs for better long-term support. If any
documents fail to process, they will be logged in the response but will not cause the entire operation to fail.
```python
from llama_stack_client import RAGDocument

View file

@ -6,6 +6,7 @@ data:
apis:
- agents
- inference
- files
- safety
- telemetry
- tool_runtime
@ -19,13 +20,6 @@ data:
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: vllm-safety
provider_type: remote::vllm
config:
url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
@ -41,6 +35,14 @@ data:
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
@ -111,9 +113,6 @@ data:
- model_id: ${env.INFERENCE_MODEL}
provider_id: vllm-inference
model_type: llm
- model_id: ${env.SAFETY_MODEL}
provider_id: vllm-safety
model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
vector_dbs: []

View file

@ -3,6 +3,7 @@ image_name: kubernetes-benchmark-demo
apis:
- agents
- inference
- files
- safety
- telemetry
- tool_runtime
@ -31,6 +32,14 @@ providers:
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard

View file

@ -1,137 +1,55 @@
apiVersion: v1
data:
stack_run_config.yaml: |
version: '2'
image_name: kubernetes-demo
apis:
- agents
- inference
- safety
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: vllm-inference
provider_type: remote::vllm
config:
url: ${env.VLLM_URL:=http://localhost:8000/v1}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: vllm-safety
provider_type: remote::vllm
config:
url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1}
max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
api_token: ${env.VLLM_API_TOKEN:=fake}
tls_verify: ${env.VLLM_TLS_VERIFY:=true}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
vector_io:
- provider_id: ${env.ENABLE_CHROMADB:+chromadb}
provider_type: remote::chromadb
config:
url: ${env.CHROMADB_URL:=}
kvstore:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
responses_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
sinks: ${env.TELEMETRY_SINKS:=console}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:+}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:+}
max_results: 3
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
table_name: llamastack_kvstore
inference_store:
type: postgres
host: ${env.POSTGRES_HOST:=localhost}
port: ${env.POSTGRES_PORT:=5432}
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
models:
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: sentence-transformers
model_type: embedding
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: vllm-inference
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
provider_id: vllm-safety
model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321
auth:
provider_config:
type: github_token
stack_run_config.yaml: "version: '2'\nimage_name: kubernetes-demo\napis:\n- agents\n-
inference\n- files\n- safety\n- telemetry\n- tool_runtime\n- vector_io\nproviders:\n
\ inference:\n - provider_id: vllm-inference\n provider_type: remote::vllm\n
\ config:\n url: ${env.VLLM_URL:=http://localhost:8000/v1}\n max_tokens:
${env.VLLM_MAX_TOKENS:=4096}\n api_token: ${env.VLLM_API_TOKEN:=fake}\n tls_verify:
${env.VLLM_TLS_VERIFY:=true}\n - provider_id: vllm-safety\n provider_type:
remote::vllm\n config:\n url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1}\n
\ max_tokens: ${env.VLLM_MAX_TOKENS:=4096}\n api_token: ${env.VLLM_API_TOKEN:=fake}\n
\ tls_verify: ${env.VLLM_TLS_VERIFY:=true}\n - provider_id: sentence-transformers\n
\ provider_type: inline::sentence-transformers\n config: {}\n vector_io:\n
\ - provider_id: ${env.ENABLE_CHROMADB:+chromadb}\n provider_type: remote::chromadb\n
\ config:\n url: ${env.CHROMADB_URL:=}\n kvstore:\n type: postgres\n
\ host: ${env.POSTGRES_HOST:=localhost}\n port: ${env.POSTGRES_PORT:=5432}\n
\ db: ${env.POSTGRES_DB:=llamastack}\n user: ${env.POSTGRES_USER:=llamastack}\n
\ password: ${env.POSTGRES_PASSWORD:=llamastack}\n files:\n - provider_id:
meta-reference-files\n provider_type: inline::localfs\n config:\n storage_dir:
${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}\n metadata_store:\n
\ type: sqlite\n db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
\ \n safety:\n - provider_id: llama-guard\n provider_type: inline::llama-guard\n
\ config:\n excluded_categories: []\n agents:\n - provider_id: meta-reference\n
\ provider_type: inline::meta-reference\n config:\n persistence_store:\n
\ type: postgres\n host: ${env.POSTGRES_HOST:=localhost}\n port:
${env.POSTGRES_PORT:=5432}\n db: ${env.POSTGRES_DB:=llamastack}\n user:
${env.POSTGRES_USER:=llamastack}\n password: ${env.POSTGRES_PASSWORD:=llamastack}\n
\ responses_store:\n type: postgres\n host: ${env.POSTGRES_HOST:=localhost}\n
\ port: ${env.POSTGRES_PORT:=5432}\n db: ${env.POSTGRES_DB:=llamastack}\n
\ user: ${env.POSTGRES_USER:=llamastack}\n password: ${env.POSTGRES_PASSWORD:=llamastack}\n
\ telemetry:\n - provider_id: meta-reference\n provider_type: inline::meta-reference\n
\ config:\n service_name: \"${env.OTEL_SERVICE_NAME:=\\u200B}\"\n sinks:
${env.TELEMETRY_SINKS:=console}\n tool_runtime:\n - provider_id: brave-search\n
\ provider_type: remote::brave-search\n config:\n api_key: ${env.BRAVE_SEARCH_API_KEY:+}\n
\ max_results: 3\n - provider_id: tavily-search\n provider_type: remote::tavily-search\n
\ config:\n api_key: ${env.TAVILY_SEARCH_API_KEY:+}\n max_results:
3\n - provider_id: rag-runtime\n provider_type: inline::rag-runtime\n config:
{}\n - provider_id: model-context-protocol\n provider_type: remote::model-context-protocol\n
\ config: {}\nmetadata_store:\n type: postgres\n host: ${env.POSTGRES_HOST:=localhost}\n
\ port: ${env.POSTGRES_PORT:=5432}\n db: ${env.POSTGRES_DB:=llamastack}\n user:
${env.POSTGRES_USER:=llamastack}\n password: ${env.POSTGRES_PASSWORD:=llamastack}\n
\ table_name: llamastack_kvstore\ninference_store:\n type: postgres\n host:
${env.POSTGRES_HOST:=localhost}\n port: ${env.POSTGRES_PORT:=5432}\n db: ${env.POSTGRES_DB:=llamastack}\n
\ user: ${env.POSTGRES_USER:=llamastack}\n password: ${env.POSTGRES_PASSWORD:=llamastack}\nmodels:\n-
metadata:\n embedding_dimension: 384\n model_id: all-MiniLM-L6-v2\n provider_id:
sentence-transformers\n model_type: embedding\n- metadata: {}\n model_id: ${env.INFERENCE_MODEL}\n
\ provider_id: vllm-inference\n model_type: llm\n- metadata: {}\n model_id:
${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}\n provider_id: vllm-safety\n
\ model_type: llm\nshields:\n- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}\nvector_dbs:
[]\ndatasets: []\nscoring_fns: []\nbenchmarks: []\ntool_groups:\n- toolgroup_id:
builtin::websearch\n provider_id: tavily-search\n- toolgroup_id: builtin::rag\n
\ provider_id: rag-runtime\nserver:\n port: 8321\n auth:\n provider_config:\n
\ type: github_token\n"
kind: ConfigMap
metadata:
creationTimestamp: null

View file

@ -3,6 +3,7 @@ image_name: kubernetes-demo
apis:
- agents
- inference
- files
- safety
- telemetry
- tool_runtime
@ -38,6 +39,14 @@ providers:
db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard

View file

@ -45,6 +45,7 @@ from llama_stack.core.utils.dynamic import instantiate_class_type
from llama_stack.core.utils.exec import formulate_run_args, run_command
from llama_stack.core.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions"
@ -294,6 +295,12 @@ def _generate_run_config(
if build_config.external_providers_dir
else EXTERNAL_PROVIDERS_DIR,
)
if not run_config.inference_store:
run_config.inference_store = SqliteSqlStoreConfig(
**SqliteSqlStoreConfig.sample_run_config(
__distro_dir__=(DISTRIBS_BASE_DIR / image_name).as_posix(), db_name="inference_store.db"
)
)
# build providers dict
provider_registry = get_provider_registry(build_config)
for api in apis:

View file

@ -10,7 +10,6 @@ import json
import logging # allow-direct-logging
import os
import sys
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from io import BytesIO
from pathlib import Path
@ -148,7 +147,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
self.async_client = AsyncLlamaStackAsLibraryClient(
config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
)
self.pool_executor = ThreadPoolExecutor(max_workers=4)
self.provider_data = provider_data
self.loop = asyncio.new_event_loop()

View file

@ -8,7 +8,7 @@
from jinja2 import Template
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import UserMessage
from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.apis.tools.rag_tool import (
DefaultRAGQueryGeneratorConfig,
LLMRAGQueryGeneratorConfig,
@ -61,16 +61,16 @@ async def llm_rag_query_generator(
messages = [interleaved_content_as_str(content)]
template = Template(config.template)
content = template.render({"messages": messages})
rendered_content: str = template.render({"messages": messages})
model = config.model
message = UserMessage(content=content)
response = await inference_api.chat_completion(
model_id=model,
message = OpenAIUserMessageParam(content=rendered_content)
response = await inference_api.openai_chat_completion(
model=model,
messages=[message],
stream=False,
)
query = response.completion_message.content
query = response.choices[0].message.content
return query

View file

@ -45,10 +45,7 @@ from llama_stack.apis.vector_io import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import (
content_from_doc,
parse_data_url,
)
from llama_stack.providers.utils.memory.vector_store import parse_data_url
from .config import RagToolRuntimeConfig
from .context_retriever import generate_rag_query
@ -60,6 +57,47 @@ def make_random_string(length: int = 8):
return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))
async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
"""Get raw binary data and mime type from a RAGDocument for file upload."""
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
parts = parse_data_url(doc.content.uri)
mime_type = parts["mimetype"]
data = parts["data"]
if parts["is_base64"]:
file_data = base64.b64decode(data)
else:
file_data = data.encode("utf-8")
return file_data, mime_type
else:
async with httpx.AsyncClient() as client:
r = await client.get(doc.content.uri)
r.raise_for_status()
mime_type = r.headers.get("content-type", "application/octet-stream")
return r.content, mime_type
else:
if isinstance(doc.content, str):
content_str = doc.content
else:
content_str = interleaved_content_as_str(doc.content)
if content_str.startswith("data:"):
parts = parse_data_url(content_str)
mime_type = parts["mimetype"]
data = parts["data"]
if parts["is_base64"]:
file_data = base64.b64decode(data)
else:
file_data = data.encode("utf-8")
return file_data, mime_type
else:
return content_str.encode("utf-8"), "text/plain"
class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
def __init__(
self,
@ -95,20 +133,12 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
return
for doc in documents:
if isinstance(doc.content, URL):
if doc.content.uri.startswith("data:"):
parts = parse_data_url(doc.content.uri)
file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode()
mime_type = parts["mimetype"]
else:
async with httpx.AsyncClient() as client:
response = await client.get(doc.content.uri)
file_data = response.content
mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream")
else:
content_str = await content_from_doc(doc)
file_data = content_str.encode("utf-8")
mime_type = doc.mime_type or "text/plain"
try:
try:
file_data, mime_type = await raw_data_from_doc(doc)
except Exception as e:
log.error(f"Failed to extract content from document {doc.document_id}: {e}")
continue
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
@ -118,9 +148,13 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
upload_file = UploadFile(file=file_obj, filename=filename)
try:
created_file = await self.files_api.openai_upload_file(
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
)
except Exception as e:
log.error(f"Failed to upload file for document {doc.document_id}: {e}")
continue
chunking_strategy = VectorStoreChunkingStrategyStatic(
static=VectorStoreChunkingStrategyStaticConfig(
@ -129,12 +163,22 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
)
)
try:
await self.vector_io_api.openai_attach_file_to_vector_store(
vector_store_id=vector_db_id,
file_id=created_file.id,
attributes=doc.metadata,
chunking_strategy=chunking_strategy,
)
except Exception as e:
log.error(
f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}"
)
continue
except Exception as e:
log.error(f"Unexpected error processing document {doc.document_id}: {e}")
continue
async def query(
self,
@ -167,8 +211,18 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
for vector_db_id in vector_db_ids
]
results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
chunks = [c for r in results for c in r.chunks]
scores = [s for r in results for s in r.scores]
chunks = []
scores = []
for vector_db_id, result in zip(vector_db_ids, results, strict=False):
for chunk, score in zip(result.chunks, result.scores, strict=False):
if not hasattr(chunk, "metadata") or chunk.metadata is None:
chunk.metadata = {}
chunk.metadata["vector_db_id"] = vector_db_id
chunks.append(chunk)
scores.append(score)
if not chunks:
return RAGQueryResult(content=None)
@ -203,6 +257,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
metadata_keys_to_exclude_from_context = [
"token_count",
"metadata_token_count",
"vector_db_id",
]
metadata_for_context = {}
for k in chunk_metadata_keys_to_include_from_context:
@ -227,6 +282,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
"document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
"chunks": [c.content for c in chunks[: len(picked)]],
"scores": scores[: len(picked)],
"vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]],
},
)
@ -262,7 +318,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
if query_config:
query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
else:
# handle someone passing an empty dict
query_config = RAGQueryConfig()
query = kwargs["query"]
@ -273,6 +328,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
)
return ToolInvocationResult(
content=result.content,
content=result.content or [],
metadata=result.metadata,
)

View file

@ -218,7 +218,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference,
adapter=AdapterSpec(
adapter_type="vertexai",
pip_packages=["litellm", "google-cloud-aiplatform"],
pip_packages=["litellm", "google-cloud-aiplatform", "openai"],
module="llama_stack.providers.remote.inference.vertexai",
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",

View file

@ -6,16 +6,20 @@
from typing import Any
import google.auth.transport.requests
from google.auth import default
from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.utils.inference.litellm_openai_mixin import (
LiteLLMOpenAIMixin,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import VertexAIConfig
from .models import MODEL_ENTRIES
class VertexAIInferenceAdapter(LiteLLMOpenAIMixin):
class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: VertexAIConfig) -> None:
LiteLLMOpenAIMixin.__init__(
self,
@ -27,10 +31,31 @@ class VertexAIInferenceAdapter(LiteLLMOpenAIMixin):
self.config = config
def get_api_key(self) -> str:
# Vertex AI doesn't use API keys, it uses Application Default Credentials
# Return empty string to let litellm handle authentication via ADC
"""
Get an access token for Vertex AI using Application Default Credentials.
Vertex AI uses ADC instead of API keys. This method obtains an access token
from the default credentials and returns it for use with the OpenAI-compatible client.
"""
try:
# Get default credentials - will read from GOOGLE_APPLICATION_CREDENTIALS
credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
credentials.refresh(google.auth.transport.requests.Request())
return str(credentials.token)
except Exception:
# If we can't get credentials, return empty string to let LiteLLM handle it
# This allows the LiteLLM mixin to work with ADC directly
return ""
def get_base_url(self) -> str:
"""
Get the Vertex AI OpenAI-compatible API base URL.
Returns the Vertex AI OpenAI-compatible endpoint URL.
Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
"""
return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"
async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
# Get base parameters from parent
params = await super()._get_params(request)

71
scripts/get_setup_env.py Executable file
View file

@ -0,0 +1,71 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
"""
Small helper script to extract environment variables from a test setup.
Used by integration-tests.sh to set environment variables before starting the server.
"""
import argparse
import sys
from tests.integration.suites import SETUP_DEFINITIONS, SUITE_DEFINITIONS
def get_setup_env_vars(setup_name, suite_name=None):
"""
Get environment variables for a setup, with optional suite default fallback.
Args:
setup_name: Name of the setup (e.g., 'ollama', 'gpt')
suite_name: Optional suite name to get default setup if setup_name is None
Returns:
Dictionary of environment variables
"""
# If no setup specified, try to get default from suite
if not setup_name and suite_name:
suite = SUITE_DEFINITIONS.get(suite_name)
if suite and suite.default_setup:
setup_name = suite.default_setup
if not setup_name:
return {}
setup = SETUP_DEFINITIONS.get(setup_name)
if not setup:
print(
f"Error: Unknown setup '{setup_name}'. Available: {', '.join(sorted(SETUP_DEFINITIONS.keys()))}",
file=sys.stderr,
)
sys.exit(1)
return setup.env
def main():
parser = argparse.ArgumentParser(description="Extract environment variables from a test setup")
parser.add_argument("--setup", help="Setup name (e.g., ollama, gpt)")
parser.add_argument("--suite", help="Suite name to get default setup from if --setup not provided")
parser.add_argument("--format", choices=["bash", "json"], default="bash", help="Output format (default: bash)")
args = parser.parse_args()
env_vars = get_setup_env_vars(args.setup, args.suite)
if args.format == "bash":
# Output as bash export statements
for key, value in env_vars.items():
print(f"export {key}='{value}'")
elif args.format == "json":
import json
print(json.dumps(env_vars))
if __name__ == "__main__":
main()

View file

@ -14,7 +14,7 @@ set -euo pipefail
# Default values
BRANCH=""
TEST_SUBDIRS=""
TEST_PROVIDER="ollama"
TEST_SETUP="ollama"
TEST_SUITE="base"
TEST_PATTERN=""
@ -27,24 +27,24 @@ Trigger the integration test recording workflow remotely. This way you do not ne
OPTIONS:
-b, --branch BRANCH Branch to run the workflow on (defaults to current branch)
-p, --test-provider PROVIDER Test provider to use: vllm or ollama (default: ollama)
-t, --test-suite SUITE Test suite to use: base, responses, vision, etc. (default: base)
-s, --test-subdirs DIRS Comma-separated list of test subdirectories to run (overrides suite)
-k, --test-pattern PATTERN Regex pattern to pass to pytest -k
-t, --suite SUITE Test suite to use: base, responses, vision, etc. (default: base)
-p, --setup SETUP Test setup to use: vllm, ollama, gpt, etc. (default: ollama)
-s, --subdirs DIRS Comma-separated list of test subdirectories to run (overrides suite)
-k, --pattern PATTERN Regex pattern to pass to pytest -k
-h, --help Show this help message
EXAMPLES:
# Record tests for current branch with agents subdirectory
$0 --test-subdirs "agents"
$0 --subdirs "agents"
# Record tests for specific branch with vision tests
$0 -b my-feature-branch --test-suite vision
$0 -b my-feature-branch --suite vision
# Record multiple test subdirectories with specific provider
$0 --test-subdirs "agents,inference" --test-provider vllm
# Record multiple test subdirectories with specific setup
$0 --subdirs "agents,inference" --setup vllm
# Record tests matching a specific pattern
$0 --test-subdirs "inference" --test-pattern "test_streaming"
$0 --subdirs "inference" --pattern "test_streaming"
EOF
}
@ -63,19 +63,19 @@ while [[ $# -gt 0 ]]; do
BRANCH="$2"
shift 2
;;
-s|--test-subdirs)
-s|--subdirs)
TEST_SUBDIRS="$2"
shift 2
;;
-p|--test-provider)
TEST_PROVIDER="$2"
-p|--setup)
TEST_SETUP="$2"
shift 2
;;
-t|--test-suite)
-t|--suite)
TEST_SUITE="$2"
shift 2
;;
-k|--test-pattern)
-k|--pattern)
TEST_PATTERN="$2"
shift 2
;;
@ -93,21 +93,16 @@ done
# Validate required parameters
if [[ -z "$TEST_SUBDIRS" && -z "$TEST_SUITE" ]]; then
echo "Error: --test-subdirs or --test-suite is required"
echo "Error: --subdirs or --suite is required"
echo "Please specify which test subdirectories to run or test suite to use, e.g.:"
echo " $0 --test-subdirs \"agents,inference\""
echo " $0 --test-suite vision"
echo " $0 --subdirs \"agents,inference\""
echo " $0 --suite vision"
echo ""
exit 1
fi
# Validate test provider
if [[ "$TEST_PROVIDER" != "vllm" && "$TEST_PROVIDER" != "ollama" ]]; then
echo "❌ Error: Invalid test provider '$TEST_PROVIDER'"
echo " Supported providers: vllm, ollama"
echo " Example: $0 --test-subdirs \"agents\" --test-provider vllm"
exit 1
fi
# Validate test setup (optional - setups are validated by the workflow itself)
# Common setups: ollama, vllm, gpt, etc.
# Check if required tools are installed
if ! command -v gh &> /dev/null; then
@ -237,7 +232,7 @@ fi
# Build the workflow dispatch command
echo "Triggering integration test recording workflow..."
echo "Branch: $BRANCH"
echo "Test provider: $TEST_PROVIDER"
echo "Test setup: $TEST_SETUP"
echo "Test subdirs: $TEST_SUBDIRS"
echo "Test suite: $TEST_SUITE"
echo "Test pattern: ${TEST_PATTERN:-"(none)"}"
@ -245,16 +240,16 @@ echo ""
# Prepare inputs for gh workflow run
if [[ -n "$TEST_SUBDIRS" ]]; then
INPUTS="-f test-subdirs='$TEST_SUBDIRS'"
INPUTS="-f subdirs='$TEST_SUBDIRS'"
fi
if [[ -n "$TEST_PROVIDER" ]]; then
INPUTS="$INPUTS -f test-provider='$TEST_PROVIDER'"
if [[ -n "$TEST_SETUP" ]]; then
INPUTS="$INPUTS -f test-setup='$TEST_SETUP'"
fi
if [[ -n "$TEST_SUITE" ]]; then
INPUTS="$INPUTS -f test-suite='$TEST_SUITE'"
INPUTS="$INPUTS -f suite='$TEST_SUITE'"
fi
if [[ -n "$TEST_PATTERN" ]]; then
INPUTS="$INPUTS -f test-pattern='$TEST_PATTERN'"
INPUTS="$INPUTS -f pattern='$TEST_PATTERN'"
fi
# Run the workflow

View file

@ -13,10 +13,10 @@ set -euo pipefail
# Default values
STACK_CONFIG=""
PROVIDER=""
TEST_SUITE="base"
TEST_SETUP=""
TEST_SUBDIRS=""
TEST_PATTERN=""
TEST_SUITE="base"
INFERENCE_MODE="replay"
EXTRA_PARAMS=""
@ -27,29 +27,30 @@ Usage: $0 [OPTIONS]
Options:
--stack-config STRING Stack configuration to use (required)
--provider STRING Provider to use (ollama, vllm, etc.) (required)
--test-suite STRING Comma-separated list of test suites to run (default: 'base')
--suite STRING Test suite to run (default: 'base')
--setup STRING Test setup (models, env) to use (e.g., 'ollama', 'ollama-vision', 'gpt', 'vllm')
--inference-mode STRING Inference mode: record or replay (default: replay)
--test-subdirs STRING Comma-separated list of test subdirectories to run (overrides suite)
--test-pattern STRING Regex pattern to pass to pytest -k
--subdirs STRING Comma-separated list of test subdirectories to run (overrides suite)
--pattern STRING Regex pattern to pass to pytest -k
--help Show this help message
Suites are defined in tests/integration/suites.py. They are used to narrow the collection of tests and provide default model options.
Suites are defined in tests/integration/suites.py and define which tests to run.
Setups are defined in tests/integration/setups.py and provide global configuration (models, env).
You can also specify subdirectories (of tests/integration) to select tests from, which will override the suite.
Examples:
# Basic inference tests with ollama
$0 --stack-config server:ci-tests --provider ollama
$0 --stack-config server:ci-tests --suite base --setup ollama
# Multiple test directories with vllm
$0 --stack-config server:ci-tests --provider vllm --test-subdirs 'inference,agents'
$0 --stack-config server:ci-tests --subdirs 'inference,agents' --setup vllm
# Vision tests with ollama
$0 --stack-config server:ci-tests --provider ollama --test-suite vision
$0 --stack-config server:ci-tests --suite vision # default setup for this suite is ollama-vision
# Record mode for updating test recordings
$0 --stack-config server:ci-tests --provider ollama --inference-mode record
$0 --stack-config server:ci-tests --suite base --inference-mode record
EOF
}
@ -60,15 +61,15 @@ while [[ $# -gt 0 ]]; do
STACK_CONFIG="$2"
shift 2
;;
--provider)
PROVIDER="$2"
--setup)
TEST_SETUP="$2"
shift 2
;;
--test-subdirs)
--subdirs)
TEST_SUBDIRS="$2"
shift 2
;;
--test-suite)
--suite)
TEST_SUITE="$2"
shift 2
;;
@ -76,7 +77,7 @@ while [[ $# -gt 0 ]]; do
INFERENCE_MODE="$2"
shift 2
;;
--test-pattern)
--pattern)
TEST_PATTERN="$2"
shift 2
;;
@ -96,11 +97,13 @@ done
# Validate required parameters
if [[ -z "$STACK_CONFIG" ]]; then
echo "Error: --stack-config is required"
usage
exit 1
fi
if [[ -z "$PROVIDER" ]]; then
echo "Error: --provider is required"
if [[ -z "$TEST_SETUP" && -n "$TEST_SUBDIRS" ]]; then
echo "Error: --test-setup is required when --test-subdirs is provided"
usage
exit 1
fi
@ -111,7 +114,7 @@ fi
echo "=== Llama Stack Integration Test Runner ==="
echo "Stack Config: $STACK_CONFIG"
echo "Provider: $PROVIDER"
echo "Setup: $TEST_SETUP"
echo "Inference Mode: $INFERENCE_MODE"
echo "Test Suite: $TEST_SUITE"
echo "Test Subdirs: $TEST_SUBDIRS"
@ -129,21 +132,25 @@ echo ""
# Set environment variables
export LLAMA_STACK_CLIENT_TIMEOUT=300
export LLAMA_STACK_TEST_INFERENCE_MODE="$INFERENCE_MODE"
# Configure provider-specific settings
if [[ "$PROVIDER" == "ollama" ]]; then
export OLLAMA_URL="http://0.0.0.0:11434"
export TEXT_MODEL="ollama/llama3.2:3b-instruct-fp16"
export SAFETY_MODEL="ollama/llama-guard3:1b"
EXTRA_PARAMS="--safety-shield=llama-guard"
else
export VLLM_URL="http://localhost:8000/v1"
export TEXT_MODEL="vllm/meta-llama/Llama-3.2-1B-Instruct"
EXTRA_PARAMS=""
fi
THIS_DIR=$(dirname "$0")
if [[ -n "$TEST_SETUP" ]]; then
EXTRA_PARAMS="--setup=$TEST_SETUP"
fi
# Apply setup-specific environment variables (needed for server startup and tests)
echo "=== Applying Setup Environment Variables ==="
# the server needs this
export LLAMA_STACK_TEST_INFERENCE_MODE="$INFERENCE_MODE"
SETUP_ENV=$(PYTHONPATH=$THIS_DIR/.. python "$THIS_DIR/get_setup_env.py" --suite "$TEST_SUITE" --setup "$TEST_SETUP" --format bash)
echo "Setting up environment variables:"
echo "$SETUP_ENV"
eval "$SETUP_ENV"
echo ""
ROOT_DIR="$THIS_DIR/.."
cd $ROOT_DIR
@ -162,6 +169,18 @@ fi
# Start Llama Stack Server if needed
if [[ "$STACK_CONFIG" == *"server:"* ]]; then
stop_server() {
echo "Stopping Llama Stack Server..."
pids=$(lsof -i :8321 | awk 'NR>1 {print $2}')
if [[ -n "$pids" ]]; then
echo "Killing Llama Stack Server processes: $pids"
kill -9 $pids
else
echo "No Llama Stack Server processes found ?!"
fi
echo "Llama Stack Server stopped"
}
# check if server is already running
if curl -s http://localhost:8321/v1/health 2>/dev/null | grep -q "OK"; then
echo "Llama Stack Server is already running, skipping start"
@ -185,14 +204,16 @@ if [[ "$STACK_CONFIG" == *"server:"* ]]; then
done
echo ""
fi
trap stop_server EXIT ERR INT TERM
fi
# Run tests
echo "=== Running Integration Tests ==="
EXCLUDE_TESTS="builtin_tool or safety_with_image or code_interpreter or test_rag"
# Additional exclusions for vllm provider
if [[ "$PROVIDER" == "vllm" ]]; then
# Additional exclusions for vllm setup
if [[ "$TEST_SETUP" == "vllm" ]]; then
EXCLUDE_TESTS="${EXCLUDE_TESTS} or test_inference_store_tool_calls"
fi
@ -229,20 +250,22 @@ if [[ -n "$TEST_SUBDIRS" ]]; then
echo "Total test files: $(echo $TEST_FILES | wc -w)"
PYTEST_TARGET="$TEST_FILES"
EXTRA_PARAMS="$EXTRA_PARAMS --text-model=$TEXT_MODEL --embedding-model=sentence-transformers/all-MiniLM-L6-v2"
else
PYTEST_TARGET="tests/integration/"
EXTRA_PARAMS="$EXTRA_PARAMS --suite=$TEST_SUITE"
fi
set +e
set -x
pytest -s -v $PYTEST_TARGET \
--stack-config="$STACK_CONFIG" \
--inference-mode="$INFERENCE_MODE" \
-k "$PYTEST_PATTERN" \
$EXTRA_PARAMS \
--color=yes \
--capture=tee-sys
exit_code=$?
set +x
set -e
if [ $exit_code -eq 0 ]; then
@ -260,18 +283,5 @@ echo "=== System Resources After Tests ==="
free -h 2>/dev/null || echo "free command not available"
df -h
# stop server
if [[ "$STACK_CONFIG" == *"server:"* ]]; then
echo "Stopping Llama Stack Server..."
pids=$(lsof -i :8321 | awk 'NR>1 {print $2}')
if [[ -n "$pids" ]]; then
echo "Killing Llama Stack Server processes: $pids"
kill -9 $pids
else
echo "No Llama Stack Server processes found ?!"
fi
echo "Llama Stack Server stopped"
fi
echo ""
echo "=== Integration Tests Complete ==="

View file

@ -6,9 +6,7 @@ Integration tests verify complete workflows across different providers using Lla
```bash
# Run all integration tests with existing recordings
LLAMA_STACK_TEST_INFERENCE_MODE=replay \
LLAMA_STACK_TEST_RECORDING_DIR=tests/integration/recordings \
uv run --group test \
uv run --group test \
pytest -sv tests/integration/ --stack-config=starter
```
@ -42,25 +40,35 @@ Model parameters can be influenced by the following options:
Each of these are comma-separated lists and can be used to generate multiple parameter combinations. Note that tests will be skipped
if no model is specified.
### Suites (fast selection + sane defaults)
### Suites and Setups
- `--suite`: comma-separated list of named suites that both narrow which tests are collected and prefill common model options (unless you pass them explicitly).
- `--suite`: single named suite that narrows which tests are collected.
- Available suites:
- `responses`: collects tests under `tests/integration/responses`; this is a separate suite because it needs a strong tool-calling model.
- `vision`: collects only `tests/integration/inference/test_vision_inference.py`; defaults `--vision-model=ollama/llama3.2-vision:11b`, `--embedding-model=sentence-transformers/all-MiniLM-L6-v2`.
- Explicit flags always win. For example, `--suite=responses --text-model=<X>` overrides the suites text model.
- `base`: collects most tests (excludes responses and post_training)
- `responses`: collects tests under `tests/integration/responses` (needs strong tool-calling models)
- `vision`: collects only `tests/integration/inference/test_vision_inference.py`
- `--setup`: global configuration that can be used with any suite. Setups prefill model/env defaults; explicit CLI flags always win.
- Available setups:
- `ollama`: Local Ollama provider with lightweight models (sets OLLAMA_URL, uses llama3.2:3b-instruct-fp16)
- `vllm`: VLLM provider for efficient local inference (sets VLLM_URL, uses Llama-3.2-1B-Instruct)
- `gpt`: OpenAI GPT models for high-quality responses (uses gpt-4o)
- `claude`: Anthropic Claude models for high-quality responses (uses claude-3-5-sonnet)
Examples:
Examples
```bash
# Fast responses run with defaults
pytest -s -v tests/integration --stack-config=server:starter --suite=responses
# Fast responses run with a strong tool-calling model
pytest -s -v tests/integration --stack-config=server:starter --suite=responses --setup=gpt
# Fast single-file vision run with defaults
pytest -s -v tests/integration --stack-config=server:starter --suite=vision
# Fast single-file vision run with Ollama defaults
pytest -s -v tests/integration --stack-config=server:starter --suite=vision --setup=ollama
# Combine suites and override a default
pytest -s -v tests/integration --stack-config=server:starter --suite=responses,vision --embedding-model=text-embedding-3-small
# Base suite with VLLM for performance
pytest -s -v tests/integration --stack-config=server:starter --suite=base --setup=vllm
# Override a default from setup
pytest -s -v tests/integration --stack-config=server:starter \
--suite=responses --setup=gpt --embedding-model=text-embedding-3-small
```
## Examples
@ -127,14 +135,13 @@ pytest tests/integration/
### RECORD Mode
Captures API interactions for later replay:
```bash
LLAMA_STACK_TEST_INFERENCE_MODE=record \
pytest tests/integration/inference/test_new_feature.py
pytest tests/integration/inference/test_new_feature.py --inference-mode=record
```
### LIVE Mode
Tests make real API calls (but not recorded):
```bash
LLAMA_STACK_TEST_INFERENCE_MODE=live pytest tests/integration/
pytest tests/integration/ --inference-mode=live
```
By default, the recording directory is `tests/integration/recordings`. You can override this by setting the `LLAMA_STACK_TEST_RECORDING_DIR` environment variable.
@ -155,15 +162,14 @@ cat recordings/responses/abc123.json | jq '.'
#### Remote Re-recording (Recommended)
Use the automated workflow script for easier re-recording:
```bash
./scripts/github/schedule-record-workflow.sh --test-subdirs "inference,agents"
./scripts/github/schedule-record-workflow.sh --subdirs "inference,agents"
```
See the [main testing guide](../README.md#remote-re-recording-recommended) for full details.
#### Local Re-recording
```bash
# Re-record specific tests
LLAMA_STACK_TEST_INFERENCE_MODE=record \
pytest -s -v --stack-config=server:starter tests/integration/inference/test_modified.py
pytest -s -v --stack-config=server:starter tests/integration/inference/test_modified.py --inference-mode=record
```
Note that when re-recording tests, you must use a Stack pointing to a server (i.e., `server:starter`). This subtlety exists because the set of tests run in server are a superset of the set of tests run in the library client.

View file

@ -15,7 +15,7 @@ from dotenv import load_dotenv
from llama_stack.log import get_logger
from .suites import SUITE_DEFINITIONS
from .suites import SETUP_DEFINITIONS, SUITE_DEFINITIONS
logger = get_logger(__name__, category="tests")
@ -63,19 +63,33 @@ def pytest_configure(config):
key, value = env_var.split("=", 1)
os.environ[key] = value
suites_raw = config.getoption("--suite")
suites: list[str] = []
if suites_raw:
suites = [p.strip() for p in str(suites_raw).split(",") if p.strip()]
unknown = [p for p in suites if p not in SUITE_DEFINITIONS]
if unknown:
inference_mode = config.getoption("--inference-mode")
os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = inference_mode
suite = config.getoption("--suite")
if suite:
if suite not in SUITE_DEFINITIONS:
raise pytest.UsageError(f"Unknown suite: {suite}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}")
# Apply setups (global parameterizations): env + defaults
setup = config.getoption("--setup")
if suite and not setup:
setup = SUITE_DEFINITIONS[suite].default_setup
if setup:
if setup not in SETUP_DEFINITIONS:
raise pytest.UsageError(
f"Unknown suite(s): {', '.join(unknown)}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}"
f"Unknown setup '{setup}'. Available: {', '.join(sorted(SETUP_DEFINITIONS.keys()))}"
)
for suite in suites:
suite_def = SUITE_DEFINITIONS.get(suite, {})
defaults: dict = suite_def.get("defaults", {})
for dest, value in defaults.items():
setup_obj = SETUP_DEFINITIONS[setup]
logger.info(f"Applying setup '{setup}'{' for suite ' + suite if suite else ''}")
# Apply env first
for k, v in setup_obj.env.items():
if k not in os.environ:
os.environ[k] = str(v)
# Apply defaults if not provided explicitly
for dest, value in setup_obj.defaults.items():
current = getattr(config.option, dest, None)
if not current:
setattr(config.option, dest, value)
@ -120,6 +134,13 @@ def pytest_addoption(parser):
default=384,
help="Output dimensionality of the embedding model to use for testing. Default: 384",
)
parser.addoption(
"--inference-mode",
help="Inference mode: { record, replay, live } (default: replay)",
choices=["record", "replay", "live"],
default="replay",
)
parser.addoption(
"--report",
help="Path where the test report should be written, e.g. --report=/path/to/report.md",
@ -127,14 +148,18 @@ def pytest_addoption(parser):
available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys()))
suite_help = (
"Comma-separated integration test suites to narrow collection and prefill defaults. "
"Available: "
f"{available_suites}. "
"Explicit CLI flags (e.g., --text-model) override suite defaults. "
"Examples: --suite=responses or --suite=responses,vision."
f"Single test suite to run (narrows collection). Available: {available_suites}. Example: --suite=responses"
)
parser.addoption("--suite", help=suite_help)
# Global setups for any suite
available_setups = ", ".join(sorted(SETUP_DEFINITIONS.keys()))
setup_help = (
f"Global test setup configuration. Available: {available_setups}. "
"Can be used with any suite. Example: --setup=ollama"
)
parser.addoption("--setup", help=setup_help)
MODEL_SHORT_IDS = {
"meta-llama/Llama-3.2-3B-Instruct": "3B",
@ -221,16 +246,12 @@ pytest_plugins = ["tests.integration.fixtures.common"]
def pytest_ignore_collect(path: str, config: pytest.Config) -> bool:
"""Skip collecting paths outside the selected suite roots for speed."""
suites_raw = config.getoption("--suite")
if not suites_raw:
suite = config.getoption("--suite")
if not suite:
return False
names = [p.strip() for p in str(suites_raw).split(",") if p.strip()]
roots: list[str] = []
for name in names:
suite_def = SUITE_DEFINITIONS.get(name)
if suite_def:
roots.extend(suite_def.get("roots", []))
sobj = SUITE_DEFINITIONS.get(suite)
roots: list[str] = sobj.get("roots", []) if isinstance(sobj, dict) else getattr(sobj, "roots", [])
if not roots:
return False

View file

@ -76,6 +76,9 @@ def skip_if_doesnt_support_n(client_with_models, model_id):
"remote::gemini",
# https://docs.anthropic.com/en/api/openai-sdk#simple-fields
"remote::anthropic",
"remote::vertexai",
# Error code: 400 - [{'error': {'code': 400, 'message': 'Unable to submit request because candidateCount must be 1 but
# the entered value was 2. Update the candidateCount value and try again.', 'status': 'INVALID_ARGUMENT'}
):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")

View file

@ -8,46 +8,112 @@
# For example:
#
# ```bash
# pytest tests/integration/ --suite=vision
# pytest tests/integration/ --suite=vision --setup=ollama
# ```
#
# Each suite can:
# - restrict collection to specific roots (dirs or files)
# - provide default CLI option values (e.g. text_model, embedding_model, etc.)
"""
Each suite defines what to run (roots). Suites can be run with different global setups defined in setups.py.
Setups provide environment variables and model defaults that can be reused across multiple suites.
CLI examples:
pytest tests/integration --suite=responses --setup=gpt
pytest tests/integration --suite=vision --setup=ollama
pytest tests/integration --suite=base --setup=vllm
"""
from pathlib import Path
from pydantic import BaseModel, Field
this_dir = Path(__file__).parent
default_roots = [
class Suite(BaseModel):
name: str
roots: list[str]
default_setup: str | None = None
class Setup(BaseModel):
"""A reusable test configuration with environment and CLI defaults."""
name: str
description: str
defaults: dict[str, str] = Field(default_factory=dict)
env: dict[str, str] = Field(default_factory=dict)
# Global setups - can be used with any suite "technically" but in reality, some setups might work
# only for specific test suites.
SETUP_DEFINITIONS: dict[str, Setup] = {
"ollama": Setup(
name="ollama",
description="Local Ollama provider with text + safety models",
env={
"OLLAMA_URL": "http://0.0.0.0:11434",
"SAFETY_MODEL": "ollama/llama-guard3:1b",
},
defaults={
"text_model": "ollama/llama3.2:3b-instruct-fp16",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
"safety_model": "ollama/llama-guard3:1b",
"safety_shield": "llama-guard",
},
),
"ollama-vision": Setup(
name="ollama",
description="Local Ollama provider with a vision model",
env={
"OLLAMA_URL": "http://0.0.0.0:11434",
},
defaults={
"vision_model": "ollama/llama3.2-vision:11b",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
},
),
"vllm": Setup(
name="vllm",
description="vLLM provider with a text model",
env={
"VLLM_URL": "http://localhost:8000/v1",
},
defaults={
"text_model": "vllm/meta-llama/Llama-3.2-1B-Instruct",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
},
),
"gpt": Setup(
name="gpt",
description="OpenAI GPT models for high-quality responses and tool calling",
defaults={
"text_model": "openai/gpt-4o",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
},
),
}
base_roots = [
str(p)
for p in this_dir.glob("*")
if p.is_dir()
and p.name not in ("__pycache__", "fixtures", "test_cases", "recordings", "responses", "post_training")
]
SUITE_DEFINITIONS: dict[str, dict] = {
"base": {
"description": "Base suite that includes most tests but runs them with a text Ollama model",
"roots": default_roots,
"defaults": {
"text_model": "ollama/llama3.2:3b-instruct-fp16",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
},
},
"responses": {
"description": "Suite that includes only the OpenAI Responses tests; needs a strong tool-calling model",
"roots": ["tests/integration/responses"],
"defaults": {
"text_model": "openai/gpt-4o",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
},
},
"vision": {
"description": "Suite that includes only the vision tests",
"roots": ["tests/integration/inference/test_vision_inference.py"],
"defaults": {
"vision_model": "ollama/llama3.2-vision:11b",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
},
},
SUITE_DEFINITIONS: dict[str, Suite] = {
"base": Suite(
name="base",
roots=base_roots,
default_setup="ollama",
),
"responses": Suite(
name="responses",
roots=["tests/integration/responses"],
default_setup="gpt",
),
"vision": Suite(
name="vision",
roots=["tests/integration/inference/test_vision_inference.py"],
default_setup="ollama-vision",
),
}

View file

@ -183,6 +183,110 @@ def test_vector_db_insert_from_url_and_query(
assert any("llama2" in chunk.content.lower() for chunk in response2.chunks)
def test_rag_tool_openai_apis(client_with_empty_registry, embedding_model_id, embedding_dimension):
vector_db_id = "test_openai_vector_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
actual_vector_db_id = available_vector_dbs[0]
# different document formats that should work with OpenAI APIs
documents = [
Document(
document_id="text-doc",
content="This is a plain text document about machine learning algorithms.",
metadata={"type": "text", "category": "AI"},
),
Document(
document_id="url-doc",
content="https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst",
mime_type="text/plain",
metadata={"type": "url", "source": "pytorch"},
),
Document(
document_id="data-url-doc",
content="data:text/plain;base64,VGhpcyBpcyBhIGRhdGEgVVJMIGRvY3VtZW50IGFib3V0IGRlZXAgbGVhcm5pbmcu", # "This is a data URL document about deep learning."
metadata={"type": "data_url", "encoding": "base64"},
),
]
client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=actual_vector_db_id,
chunk_size_in_tokens=256,
)
files_list = client_with_empty_registry.files.list()
assert len(files_list.data) >= len(documents), (
f"Expected at least {len(documents)} files, got {len(files_list.data)}"
)
vector_store_files = client_with_empty_registry.vector_io.openai_list_files_in_vector_store(
vector_store_id=actual_vector_db_id
)
assert len(vector_store_files.data) >= len(documents), f"Expected at least {len(documents)} files in vector store"
response = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[actual_vector_db_id],
content="Tell me about machine learning and deep learning",
)
assert_valid_text_response(response)
content_text = " ".join([chunk.text for chunk in response.content]).lower()
assert "machine learning" in content_text or "deep learning" in content_text
def test_rag_tool_exception_handling(client_with_empty_registry, embedding_model_id, embedding_dimension):
vector_db_id = "test_exception_handling"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
actual_vector_db_id = available_vector_dbs[0]
documents = [
Document(
document_id="valid-doc",
content="This is a valid document that should be processed successfully.",
metadata={"status": "valid"},
),
Document(
document_id="invalid-url-doc",
content="https://nonexistent-domain-12345.com/invalid.txt",
metadata={"status": "invalid_url"},
),
Document(
document_id="another-valid-doc",
content="This is another valid document for testing resilience.",
metadata={"status": "valid"},
),
]
client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=actual_vector_db_id,
chunk_size_in_tokens=256,
)
response = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[actual_vector_db_id],
content="valid document",
)
assert_valid_text_response(response)
content_text = " ".join([chunk.text for chunk in response.content]).lower()
assert "valid document" in content_text
def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id, embedding_dimension):
providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
assert len(providers) > 0
@ -249,3 +353,107 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
"chunk_template": "This should raise a ValueError because it is missing the proper template variables",
},
)
def test_rag_tool_query_generation(client_with_empty_registry, embedding_model_id, embedding_dimension):
vector_db_id = "test_query_generation_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
actual_vector_db_id = available_vector_dbs[0]
documents = [
Document(
document_id="ai-doc",
content="Artificial intelligence and machine learning are transforming technology.",
metadata={"category": "AI"},
),
Document(
document_id="banana-doc",
content="Don't bring a banana to a knife fight.",
metadata={"category": "wisdom"},
),
]
client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=actual_vector_db_id,
chunk_size_in_tokens=256,
)
response = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[actual_vector_db_id],
content="Tell me about AI",
)
assert_valid_text_response(response)
content_text = " ".join([chunk.text for chunk in response.content]).lower()
assert "artificial intelligence" in content_text or "machine learning" in content_text
def test_rag_tool_pdf_data_url_handling(client_with_empty_registry, embedding_model_id, embedding_dimension):
vector_db_id = "test_pdf_data_url_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
actual_vector_db_id = available_vector_dbs[0]
sample_pdf = b"%PDF-1.3\n3 0 obj\n<</Type /Page\n/Parent 1 0 R\n/Resources 2 0 R\n/Contents 4 0 R>>\nendobj\n4 0 obj\n<</Filter /FlateDecode /Length 115>>\nstream\nx\x9c\x15\xcc1\x0e\x820\x18@\xe1\x9dS\xbcM]jk$\xd5\xd5(\x83!\x86\xa1\x17\xf8\xa3\xa5`LIh+\xd7W\xc6\xf7\r\xef\xc0\xbd\xd2\xaa\xb6,\xd5\xc5\xb1o\x0c\xa6VZ\xe3znn%\xf3o\xab\xb1\xe7\xa3:Y\xdc\x8bm\xeb\xf3&1\xc8\xd7\xd3\x97\xc82\xe6\x81\x87\xe42\xcb\x87Vb(\x12<\xdd<=}Jc\x0cL\x91\xee\xda$\xb5\xc3\xbd\xd7\xe9\x0f\x8d\x97 $\nendstream\nendobj\n1 0 obj\n<</Type /Pages\n/Kids [3 0 R ]\n/Count 1\n/MediaBox [0 0 595.28 841.89]\n>>\nendobj\n5 0 obj\n<</Type /Font\n/BaseFont /Helvetica\n/Subtype /Type1\n/Encoding /WinAnsiEncoding\n>>\nendobj\n2 0 obj\n<<\n/ProcSet [/PDF /Text /ImageB /ImageC /ImageI]\n/Font <<\n/F1 5 0 R\n>>\n/XObject <<\n>>\n>>\nendobj\n6 0 obj\n<<\n/Producer (PyFPDF 1.7.2 http://pyfpdf.googlecode.com/)\n/Title (This is a sample title.)\n/Author (Llama Stack Developers)\n/CreationDate (D:20250312165548)\n>>\nendobj\n7 0 obj\n<<\n/Type /Catalog\n/Pages 1 0 R\n/OpenAction [3 0 R /FitH null]\n/PageLayout /OneColumn\n>>\nendobj\nxref\n0 8\n0000000000 65535 f \n0000000272 00000 n \n0000000455 00000 n \n0000000009 00000 n \n0000000087 00000 n \n0000000359 00000 n \n0000000559 00000 n \n0000000734 00000 n \ntrailer\n<<\n/Size 8\n/Root 7 0 R\n/Info 6 0 R\n>>\nstartxref\n837\n%%EOF\n"
import base64
pdf_base64 = base64.b64encode(sample_pdf).decode("utf-8")
pdf_data_url = f"data:application/pdf;base64,{pdf_base64}"
documents = [
Document(
document_id="test-pdf-data-url",
content=pdf_data_url,
metadata={"type": "pdf", "source": "data_url"},
),
]
client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=documents,
vector_db_id=actual_vector_db_id,
chunk_size_in_tokens=256,
)
files_list = client_with_empty_registry.files.list()
assert len(files_list.data) >= 1, "PDF should have been uploaded to Files API"
pdf_file = None
for file in files_list.data:
if file.filename and "test-pdf-data-url" in file.filename:
pdf_file = file
break
assert pdf_file is not None, "PDF file should be found in Files API"
assert pdf_file.bytes == len(sample_pdf), f"File size should match original PDF ({len(sample_pdf)} bytes)"
file_content = client_with_empty_registry.files.retrieve_content(pdf_file.id)
assert file_content.startswith(b"%PDF-"), "Retrieved file should be a valid PDF"
vector_store_files = client_with_empty_registry.vector_io.openai_list_files_in_vector_store(
vector_store_id=actual_vector_db_id
)
assert len(vector_store_files.data) >= 1, "PDF should be attached to vector store"
response = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[actual_vector_db_id],
content="sample title",
)
assert_valid_text_response(response)
content_text = " ".join([chunk.text for chunk in response.content]).lower()
assert "sample title" in content_text or "title" in content_text

View file

@ -178,3 +178,41 @@ def test_content_from_data_and_mime_type_both_encodings_fail():
# Should raise an exception instead of returning empty string
with pytest.raises(UnicodeDecodeError):
content_from_data_and_mime_type(data, mime_type)
async def test_memory_tool_error_handling():
"""Test that memory tool handles various failures gracefully without crashing."""
from llama_stack.providers.inline.tool_runtime.rag.config import RagToolRuntimeConfig
from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl
config = RagToolRuntimeConfig()
memory_tool = MemoryToolRuntimeImpl(
config=config,
vector_io_api=AsyncMock(),
inference_api=AsyncMock(),
files_api=AsyncMock(),
)
docs = [
RAGDocument(document_id="good_doc", content="Good content", metadata={}),
RAGDocument(document_id="bad_url_doc", content=URL(uri="https://bad.url"), metadata={}),
RAGDocument(document_id="another_good_doc", content="Another good content", metadata={}),
]
mock_file1 = MagicMock()
mock_file1.id = "file_good1"
mock_file2 = MagicMock()
mock_file2.id = "file_good2"
memory_tool.files_api.openai_upload_file.side_effect = [mock_file1, mock_file2]
with patch("httpx.AsyncClient") as mock_client:
mock_instance = AsyncMock()
mock_instance.get.side_effect = Exception("Bad URL")
mock_client.return_value.__aenter__.return_value = mock_instance
# won't raise exception despite one document failing
await memory_tool.insert(docs, "vector_store_123")
# processed 2 documents successfully, skipped 1
assert memory_tool.files_api.openai_upload_file.call_count == 2
assert memory_tool.vector_io_api.openai_attach_file_to_vector_store.call_count == 2

View file

@ -81,3 +81,58 @@ class TestRagQuery:
# Test that invalid mode raises an error
with pytest.raises(ValueError):
RAGQueryConfig(mode="wrong_mode")
async def test_query_adds_vector_db_id_to_chunk_metadata(self):
rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(),
vector_io_api=MagicMock(),
inference_api=MagicMock(),
files_api=MagicMock(),
)
vector_db_ids = ["db1", "db2"]
# Fake chunks from each DB
chunk_metadata1 = ChunkMetadata(
document_id="doc1",
chunk_id="chunk1",
source="test_source1",
metadata_token_count=5,
)
chunk1 = Chunk(
content="chunk from db1",
metadata={"vector_db_id": "db1", "document_id": "doc1"},
stored_chunk_id="c1",
chunk_metadata=chunk_metadata1,
)
chunk_metadata2 = ChunkMetadata(
document_id="doc2",
chunk_id="chunk2",
source="test_source2",
metadata_token_count=5,
)
chunk2 = Chunk(
content="chunk from db2",
metadata={"vector_db_id": "db2", "document_id": "doc2"},
stored_chunk_id="c2",
chunk_metadata=chunk_metadata2,
)
rag_tool.vector_io_api.query_chunks = AsyncMock(
side_effect=[
QueryChunksResponse(chunks=[chunk1], scores=[0.9]),
QueryChunksResponse(chunks=[chunk2], scores=[0.8]),
]
)
result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids)
returned_chunks = result.metadata["chunks"]
returned_scores = result.metadata["scores"]
returned_doc_ids = result.metadata["document_ids"]
returned_vector_db_ids = result.metadata["vector_db_ids"]
assert returned_chunks == ["chunk from db1", "chunk from db2"]
assert returned_scores == (0.9, 0.8)
assert returned_doc_ids == ["doc1", "doc2"]
assert returned_vector_db_ids == ["db1", "db2"]