Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-04 12:07:34 +00:00)
chore: Update documentation and add exception handling for Vector Stores in the RAG Tool; switch inference to the OpenAI-compatible APIs and use existing libraries for the memory implementation
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent 28696c3f30, commit ff0bd414b1

27 changed files with 926 additions and 403 deletions
.github/actions/run-and-record-tests/action.yml (vendored) — 42 changes

@@ -5,21 +5,22 @@ inputs:
   stack-config:
     description: 'Stack configuration to use'
     required: true
-  provider:
-    description: 'Provider to use for tests'
-    required: true
+  setup:
+    description: 'Setup to use for tests (e.g., ollama, gpt, vllm)'
+    required: false
+    default: ''
   inference-mode:
     description: 'Inference mode (record or replay)'
     required: true
-  test-suite:
+  suite:
     description: 'Test suite to use: base, responses, vision, etc.'
     required: false
     default: ''
-  test-subdirs:
-    description: 'Comma-separated list of test subdirectories to run; overrides test-suite'
+  subdirs:
+    description: 'Comma-separated list of test subdirectories to run; overrides suite'
     required: false
     default: ''
-  test-pattern:
+  pattern:
     description: 'Regex pattern to pass to pytest -k'
     required: false
     default: ''

@@ -37,14 +38,23 @@ runs:
   - name: Run Integration Tests
     shell: bash
     run: |
-      uv run --no-sync ./scripts/integration-tests.sh \
-        --stack-config '${{ inputs.stack-config }}' \
-        --provider '${{ inputs.provider }}' \
-        --test-subdirs '${{ inputs.test-subdirs }}' \
-        --test-pattern '${{ inputs.test-pattern }}' \
-        --inference-mode '${{ inputs.inference-mode }}' \
-        --test-suite '${{ inputs.test-suite }}' \
-        | tee pytest-${{ inputs.inference-mode }}.log
+      SCRIPT_ARGS="--stack-config ${{ inputs.stack-config }} --inference-mode ${{ inputs.inference-mode }}"
+
+      # Add optional arguments only if they are provided
+      if [ -n '${{ inputs.setup }}' ]; then
+        SCRIPT_ARGS="$SCRIPT_ARGS --setup ${{ inputs.setup }}"
+      fi
+      if [ -n '${{ inputs.suite }}' ]; then
+        SCRIPT_ARGS="$SCRIPT_ARGS --suite ${{ inputs.suite }}"
+      fi
+      if [ -n '${{ inputs.subdirs }}' ]; then
+        SCRIPT_ARGS="$SCRIPT_ARGS --subdirs ${{ inputs.subdirs }}"
+      fi
+      if [ -n '${{ inputs.pattern }}' ]; then
+        SCRIPT_ARGS="$SCRIPT_ARGS --pattern ${{ inputs.pattern }}"
+      fi
+
+      uv run --no-sync ./scripts/integration-tests.sh $SCRIPT_ARGS | tee pytest-${{ inputs.inference-mode }}.log

   - name: Commit and push recordings

@@ -58,7 +68,7 @@ runs:
         echo "New recordings detected, committing and pushing"
         git add tests/integration/recordings/

-        git commit -m "Recordings update from CI (test-suite: ${{ inputs.test-suite }})"
+        git commit -m "Recordings update from CI (suite: ${{ inputs.suite }})"
         git fetch origin ${{ github.ref_name }}
         git rebase origin/${{ github.ref_name }}
         echo "Rebased successfully"
.github/actions/setup-ollama/action.yml (vendored) — 4 changes

@@ -1,7 +1,7 @@
 name: Setup Ollama
 description: Start Ollama
 inputs:
-  test-suite:
+  suite:
     description: 'Test suite to use: base, responses, vision, etc.'
     required: false
     default: ''

@@ -11,7 +11,7 @@ runs:
   - name: Start Ollama
     shell: bash
     run: |
-      if [ "${{ inputs.test-suite }}" == "vision" ]; then
+      if [ "${{ inputs.suite }}" == "vision" ]; then
        image="ollama-with-vision-model"
      else
        image="ollama-with-models"
@@ -8,11 +8,11 @@ inputs:
   client-version:
     description: 'Client version (latest or published)'
     required: true
-  provider:
-    description: 'Provider to setup (ollama or vllm)'
-    required: true
+  setup:
+    description: 'Setup to configure (ollama, vllm, gpt, etc.)'
+    required: false
     default: 'ollama'
-  test-suite:
+  suite:
     description: 'Test suite to use: base, responses, vision, etc.'
     required: false
     default: ''

@@ -30,13 +30,13 @@ runs:
       client-version: ${{ inputs.client-version }}

   - name: Setup ollama
-    if: ${{ inputs.provider == 'ollama' && inputs.inference-mode == 'record' }}
+    if: ${{ (inputs.setup == 'ollama' || inputs.setup == 'ollama-vision') && inputs.inference-mode == 'record' }}
     uses: ./.github/actions/setup-ollama
     with:
-      test-suite: ${{ inputs.test-suite }}
+      suite: ${{ inputs.suite }}

   - name: Setup vllm
-    if: ${{ inputs.provider == 'vllm' && inputs.inference-mode == 'record' }}
+    if: ${{ inputs.setup == 'vllm' && inputs.inference-mode == 'record' }}
     uses: ./.github/actions/setup-vllm

   - name: Build Llama Stack
.github/workflows/integration-tests.yml (vendored) — 20 changes

@@ -28,8 +28,8 @@ on:
       description: 'Test against both the latest and published versions'
       type: boolean
       default: false
-    test-provider:
-      description: 'Test against a specific provider'
+    test-setup:
+      description: 'Test against a specific setup'
       type: string
       default: 'ollama'

@@ -42,18 +42,18 @@ jobs:
   run-replay-mode-tests:
     runs-on: ubuntu-latest
-    name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, {4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.test-suite) }}
+    name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, {4})', matrix.client-type, matrix.setup, matrix.python-version, matrix.client-version, matrix.suite) }}

     strategy:
       fail-fast: false
       matrix:
         client-type: [library, server]
-        # Use vllm on weekly schedule, otherwise use test-provider input (defaults to ollama)
-        provider: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-provider || 'ollama')) }}
+        # Use vllm on weekly schedule, otherwise use test-setup input (defaults to ollama)
+        setup: ${{ (github.event.schedule == '1 0 * * 0') && fromJSON('["vllm"]') || fromJSON(format('["{0}"]', github.event.inputs.test-setup || 'ollama')) }}
         # Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
         python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
         client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
-        test-suite: [base, vision]
+        suite: [base, vision]

     steps:
       - name: Checkout repository

@@ -64,14 +64,14 @@ jobs:
         with:
           python-version: ${{ matrix.python-version }}
           client-version: ${{ matrix.client-version }}
-          provider: ${{ matrix.provider }}
-          test-suite: ${{ matrix.test-suite }}
+          setup: ${{ matrix.setup }}
+          suite: ${{ matrix.suite }}
           inference-mode: 'replay'

       - name: Run tests
         uses: ./.github/actions/run-and-record-tests
         with:
           stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }}
-          provider: ${{ matrix.provider }}
+          setup: ${{ matrix.setup }}
           inference-mode: 'replay'
-          test-suite: ${{ matrix.test-suite }}
+          suite: ${{ matrix.suite }}
.github/workflows/pre-commit.yml (vendored) — 1 change

@@ -48,7 +48,6 @@ jobs:
         working-directory: llama_stack/ui

       - uses: pre-commit/action@2c7b3805fd2a0fd8c1884dcaebf91fc102a13ecd # v3.0.1
-        continue-on-error: true
         env:
           SKIP: no-commit-to-branch
           RUFF_OUTPUT_FORMAT: github
.github/workflows/record-integration-tests.yml (vendored) — 32 changes

@@ -10,19 +10,19 @@ run-name: Run the integration test suite from tests/integration
 on:
   workflow_dispatch:
     inputs:
-      test-provider:
-        description: 'Test against a specific provider'
+      test-setup:
+        description: 'Test against a specific setup'
         type: string
         default: 'ollama'
-      test-suite:
+      suite:
         description: 'Test suite to use: base, responses, vision, etc.'
         type: string
         default: ''
-      test-subdirs:
-        description: 'Comma-separated list of test subdirectories to run; overrides test-suite'
+      subdirs:
+        description: 'Comma-separated list of test subdirectories to run; overrides suite'
         type: string
         default: ''
-      test-pattern:
+      pattern:
         description: 'Regex pattern to pass to pytest -k'
         type: string
         default: ''

@@ -39,10 +39,10 @@ jobs:
         run: |
           echo "::group::Workflow Inputs"
           echo "branch: ${{ github.ref_name }}"
-          echo "test-provider: ${{ inputs.test-provider }}"
-          echo "test-suite: ${{ inputs.test-suite }}"
-          echo "test-subdirs: ${{ inputs.test-subdirs }}"
-          echo "test-pattern: ${{ inputs.test-pattern }}"
+          echo "test-setup: ${{ inputs.test-setup }}"
+          echo "suite: ${{ inputs.suite }}"
+          echo "subdirs: ${{ inputs.subdirs }}"
+          echo "pattern: ${{ inputs.pattern }}"
           echo "::endgroup::"

       - name: Checkout repository

@@ -55,16 +55,16 @@ jobs:
         with:
           python-version: "3.12"  # Use single Python version for recording
           client-version: "latest"
-          provider: ${{ inputs.test-provider || 'ollama' }}
-          test-suite: ${{ inputs.test-suite }}
+          setup: ${{ inputs.test-setup || 'ollama' }}
+          suite: ${{ inputs.suite }}
           inference-mode: 'record'

       - name: Run and record tests
         uses: ./.github/actions/run-and-record-tests
         with:
           stack-config: 'server:ci-tests'  # recording must be done with server since more tests are run
-          provider: ${{ inputs.test-provider || 'ollama' }}
+          setup: ${{ inputs.test-setup || 'ollama' }}
           inference-mode: 'record'
-          test-suite: ${{ inputs.test-suite }}
-          test-subdirs: ${{ inputs.test-subdirs }}
-          test-pattern: ${{ inputs.test-pattern }}
+          suite: ${{ inputs.suite }}
+          subdirs: ${{ inputs.subdirs }}
+          pattern: ${{ inputs.pattern }}
@@ -93,10 +93,31 @@ chunks_response = client.vector_io.query(

 ### Using the RAG Tool

+> **⚠️ DEPRECATION NOTICE**: The RAG Tool is being deprecated in favor of directly using the OpenAI-compatible Search
+> API. We recommend migrating to the OpenAI APIs for better compatibility and future support.
+
 A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc.
 and automatically chunks them into smaller pieces. More examples for how to format a RAGDocument can be found in the
 [appendix](#more-ragdocument-examples).

+#### OpenAI API Integration & Migration
+
+The RAG tool has been updated to use OpenAI-compatible APIs. This provides several benefits:
+
+- **Files API Integration**: Documents are now uploaded using OpenAI's file upload endpoints
+- **Vector Stores API**: Vector storage operations use OpenAI's vector store format with configurable chunking strategies
+- **Error Resilience**: When processing multiple documents, individual failures are logged but don't crash the operation. Failed documents are skipped while successful ones continue processing.
+
+**Migration Path:**
+We recommend migrating to the OpenAI-compatible Search API for:
+1. **Better OpenAI Ecosystem Integration**: Direct compatibility with OpenAI tools and workflows, including the Responses API
+2. **Future-Proof**: Continued support and feature development
+3. **Full OpenAI Compatibility**: Vector Stores, Files, and Search APIs are fully compatible with OpenAI's Responses API
+
+The OpenAI APIs are used under the hood, so you can continue to use your existing RAG Tool code with minimal changes.
+However, we recommend updating your code to use the new OpenAI-compatible APIs for better long-term support. If any
+documents fail to process, they will be logged in the response but will not cause the entire operation to fail.
+
 ```python
 from llama_stack_client import RAGDocument
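As an illustration of the migration path described above, the sketch below ingests a document and runs a search directly against the OpenAI-compatible Files and Vector Stores endpoints of a Llama Stack server. It is a minimal sketch, not part of this commit: the base URL, the `none` API key, and the `docs/shields.md` filename are assumptions for the example, and on older `openai` SDK versions the vector-store methods live under `client.beta.vector_stores`.

```python
# Illustrative sketch (not from this commit): ingest and search via the
# OpenAI-compatible Files / Vector Stores APIs. Base URL and filename are assumed.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")

# Upload the document through the Files API
with open("docs/shields.md", "rb") as f:
    uploaded = client.files.create(file=f, purpose="assistants")

# Create a vector store and attach the uploaded file to it
store = client.vector_stores.create(name="docs")
client.vector_stores.files.create(vector_store_id=store.id, file_id=uploaded.id)

# Query it with the Search API (client.beta.vector_stores on older SDKs)
results = client.vector_stores.search(vector_store_id=store.id, query="How do I configure shields?")
for hit in results.data:
    print(hit.filename, hit.score)
```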
@@ -6,6 +6,7 @@ data:
     apis:
     - agents
     - inference
+    - files
     - safety
     - telemetry
     - tool_runtime

@@ -19,13 +20,6 @@ data:
           max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
           api_token: ${env.VLLM_API_TOKEN:=fake}
           tls_verify: ${env.VLLM_TLS_VERIFY:=true}
-      - provider_id: vllm-safety
-        provider_type: remote::vllm
-        config:
-          url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1}
-          max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
-          api_token: ${env.VLLM_API_TOKEN:=fake}
-          tls_verify: ${env.VLLM_TLS_VERIFY:=true}
       - provider_id: sentence-transformers
         provider_type: inline::sentence-transformers
         config: {}

@@ -41,6 +35,14 @@ data:
           db: ${env.POSTGRES_DB:=llamastack}
           user: ${env.POSTGRES_USER:=llamastack}
           password: ${env.POSTGRES_PASSWORD:=llamastack}
+      files:
+      - provider_id: meta-reference-files
+        provider_type: inline::localfs
+        config:
+          storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+          metadata_store:
+            type: sqlite
+            db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
       safety:
       - provider_id: llama-guard
         provider_type: inline::llama-guard

@@ -111,9 +113,6 @@ data:
     - model_id: ${env.INFERENCE_MODEL}
       provider_id: vllm-inference
       model_type: llm
-    - model_id: ${env.SAFETY_MODEL}
-      provider_id: vllm-safety
-      model_type: llm
     shields:
     - shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
     vector_dbs: []
@@ -3,6 +3,7 @@ image_name: kubernetes-benchmark-demo
 apis:
 - agents
 - inference
+- files
 - safety
 - telemetry
 - tool_runtime

@@ -31,6 +32,14 @@ providers:
       db: ${env.POSTGRES_DB:=llamastack}
       user: ${env.POSTGRES_USER:=llamastack}
       password: ${env.POSTGRES_PASSWORD:=llamastack}
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard
@@ -1,137 +1,55 @@
 apiVersion: v1
 data:
-  stack_run_config.yaml: |
-    version: '2'
-    image_name: kubernetes-demo
-    apis:
-    - agents
-    - inference
-    - safety
-    - telemetry
-    - tool_runtime
-    - vector_io
-    providers:
-      inference:
-      - provider_id: vllm-inference
-        provider_type: remote::vllm
-        config:
-          url: ${env.VLLM_URL:=http://localhost:8000/v1}
-          max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
-          api_token: ${env.VLLM_API_TOKEN:=fake}
-          tls_verify: ${env.VLLM_TLS_VERIFY:=true}
-      - provider_id: vllm-safety
-        provider_type: remote::vllm
-        config:
-          url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1}
-          max_tokens: ${env.VLLM_MAX_TOKENS:=4096}
-          api_token: ${env.VLLM_API_TOKEN:=fake}
-          tls_verify: ${env.VLLM_TLS_VERIFY:=true}
-      - provider_id: sentence-transformers
-        provider_type: inline::sentence-transformers
-        config: {}
-      vector_io:
-      - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
-        provider_type: remote::chromadb
-        config:
-          url: ${env.CHROMADB_URL:=}
-          kvstore:
-            type: postgres
-            host: ${env.POSTGRES_HOST:=localhost}
-            port: ${env.POSTGRES_PORT:=5432}
-            db: ${env.POSTGRES_DB:=llamastack}
-            user: ${env.POSTGRES_USER:=llamastack}
-            password: ${env.POSTGRES_PASSWORD:=llamastack}
-      safety:
-      - provider_id: llama-guard
-        provider_type: inline::llama-guard
-        config:
-          excluded_categories: []
-      agents:
-      - provider_id: meta-reference
-        provider_type: inline::meta-reference
-        config:
-          persistence_store:
-            type: postgres
-            host: ${env.POSTGRES_HOST:=localhost}
-            port: ${env.POSTGRES_PORT:=5432}
-            db: ${env.POSTGRES_DB:=llamastack}
-            user: ${env.POSTGRES_USER:=llamastack}
-            password: ${env.POSTGRES_PASSWORD:=llamastack}
-          responses_store:
-            type: postgres
-            host: ${env.POSTGRES_HOST:=localhost}
-            port: ${env.POSTGRES_PORT:=5432}
-            db: ${env.POSTGRES_DB:=llamastack}
-            user: ${env.POSTGRES_USER:=llamastack}
-            password: ${env.POSTGRES_PASSWORD:=llamastack}
-      telemetry:
-      - provider_id: meta-reference
-        provider_type: inline::meta-reference
-        config:
-          service_name: "${env.OTEL_SERVICE_NAME:=\u200B}"
-          sinks: ${env.TELEMETRY_SINKS:=console}
-      tool_runtime:
-      - provider_id: brave-search
-        provider_type: remote::brave-search
-        config:
-          api_key: ${env.BRAVE_SEARCH_API_KEY:+}
-          max_results: 3
-      - provider_id: tavily-search
-        provider_type: remote::tavily-search
-        config:
-          api_key: ${env.TAVILY_SEARCH_API_KEY:+}
-          max_results: 3
-      - provider_id: rag-runtime
-        provider_type: inline::rag-runtime
-        config: {}
-      - provider_id: model-context-protocol
-        provider_type: remote::model-context-protocol
-        config: {}
-    metadata_store:
-      type: postgres
-      host: ${env.POSTGRES_HOST:=localhost}
-      port: ${env.POSTGRES_PORT:=5432}
-      db: ${env.POSTGRES_DB:=llamastack}
-      user: ${env.POSTGRES_USER:=llamastack}
-      password: ${env.POSTGRES_PASSWORD:=llamastack}
-      table_name: llamastack_kvstore
-    inference_store:
-      type: postgres
-      host: ${env.POSTGRES_HOST:=localhost}
-      port: ${env.POSTGRES_PORT:=5432}
-      db: ${env.POSTGRES_DB:=llamastack}
-      user: ${env.POSTGRES_USER:=llamastack}
-      password: ${env.POSTGRES_PASSWORD:=llamastack}
-    models:
-    - metadata:
-        embedding_dimension: 384
-      model_id: all-MiniLM-L6-v2
-      provider_id: sentence-transformers
-      model_type: embedding
-    - metadata: {}
-      model_id: ${env.INFERENCE_MODEL}
-      provider_id: vllm-inference
-      model_type: llm
-    - metadata: {}
-      model_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
-      provider_id: vllm-safety
-      model_type: llm
-    shields:
-    - shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
-    vector_dbs: []
-    datasets: []
-    scoring_fns: []
-    benchmarks: []
-    tool_groups:
-    - toolgroup_id: builtin::websearch
-      provider_id: tavily-search
-    - toolgroup_id: builtin::rag
-      provider_id: rag-runtime
-    server:
-      port: 8321
-      auth:
-        provider_config:
-          type: github_token
+  stack_run_config.yaml: "version: '2'\nimage_name: kubernetes-demo\napis:\n- agents\n-
+    inference\n- files\n- safety\n- telemetry\n- tool_runtime\n- vector_io\nproviders:\n
+    \ inference:\n  - provider_id: vllm-inference\n    provider_type: remote::vllm\n
+    \    config:\n      url: ${env.VLLM_URL:=http://localhost:8000/v1}\n      max_tokens:
+    ${env.VLLM_MAX_TOKENS:=4096}\n      api_token: ${env.VLLM_API_TOKEN:=fake}\n      tls_verify:
+    ${env.VLLM_TLS_VERIFY:=true}\n  - provider_id: vllm-safety\n    provider_type:
+    remote::vllm\n    config:\n      url: ${env.VLLM_SAFETY_URL:=http://localhost:8000/v1}\n
+    \     max_tokens: ${env.VLLM_MAX_TOKENS:=4096}\n      api_token: ${env.VLLM_API_TOKEN:=fake}\n
+    \     tls_verify: ${env.VLLM_TLS_VERIFY:=true}\n  - provider_id: sentence-transformers\n
+    \   provider_type: inline::sentence-transformers\n    config: {}\n  vector_io:\n
+    \ - provider_id: ${env.ENABLE_CHROMADB:+chromadb}\n    provider_type: remote::chromadb\n
+    \   config:\n      url: ${env.CHROMADB_URL:=}\n      kvstore:\n        type: postgres\n
+    \       host: ${env.POSTGRES_HOST:=localhost}\n        port: ${env.POSTGRES_PORT:=5432}\n
+    \       db: ${env.POSTGRES_DB:=llamastack}\n        user: ${env.POSTGRES_USER:=llamastack}\n
+    \       password: ${env.POSTGRES_PASSWORD:=llamastack}\n  files:\n  - provider_id:
+    meta-reference-files\n    provider_type: inline::localfs\n    config:\n      storage_dir:
+    ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}\n      metadata_store:\n
+    \       type: sqlite\n        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
+    \ \n  safety:\n  - provider_id: llama-guard\n    provider_type: inline::llama-guard\n
+    \   config:\n      excluded_categories: []\n  agents:\n  - provider_id: meta-reference\n
+    \   provider_type: inline::meta-reference\n    config:\n      persistence_store:\n
+    \       type: postgres\n        host: ${env.POSTGRES_HOST:=localhost}\n        port:
+    ${env.POSTGRES_PORT:=5432}\n        db: ${env.POSTGRES_DB:=llamastack}\n        user:
+    ${env.POSTGRES_USER:=llamastack}\n        password: ${env.POSTGRES_PASSWORD:=llamastack}\n
+    \     responses_store:\n        type: postgres\n        host: ${env.POSTGRES_HOST:=localhost}\n
+    \       port: ${env.POSTGRES_PORT:=5432}\n        db: ${env.POSTGRES_DB:=llamastack}\n
+    \       user: ${env.POSTGRES_USER:=llamastack}\n        password: ${env.POSTGRES_PASSWORD:=llamastack}\n
+    \ telemetry:\n  - provider_id: meta-reference\n    provider_type: inline::meta-reference\n
+    \   config:\n      service_name: \"${env.OTEL_SERVICE_NAME:=\\u200B}\"\n      sinks:
+    ${env.TELEMETRY_SINKS:=console}\n  tool_runtime:\n  - provider_id: brave-search\n
+    \   provider_type: remote::brave-search\n    config:\n      api_key: ${env.BRAVE_SEARCH_API_KEY:+}\n
+    \     max_results: 3\n  - provider_id: tavily-search\n    provider_type: remote::tavily-search\n
+    \   config:\n      api_key: ${env.TAVILY_SEARCH_API_KEY:+}\n      max_results:
+    3\n  - provider_id: rag-runtime\n    provider_type: inline::rag-runtime\n    config:
+    {}\n  - provider_id: model-context-protocol\n    provider_type: remote::model-context-protocol\n
+    \   config: {}\nmetadata_store:\n  type: postgres\n  host: ${env.POSTGRES_HOST:=localhost}\n
+    \ port: ${env.POSTGRES_PORT:=5432}\n  db: ${env.POSTGRES_DB:=llamastack}\n  user:
+    ${env.POSTGRES_USER:=llamastack}\n  password: ${env.POSTGRES_PASSWORD:=llamastack}\n
+    \ table_name: llamastack_kvstore\ninference_store:\n  type: postgres\n  host:
+    ${env.POSTGRES_HOST:=localhost}\n  port: ${env.POSTGRES_PORT:=5432}\n  db: ${env.POSTGRES_DB:=llamastack}\n
+    \ user: ${env.POSTGRES_USER:=llamastack}\n  password: ${env.POSTGRES_PASSWORD:=llamastack}\nmodels:\n-
+    metadata:\n    embedding_dimension: 384\n  model_id: all-MiniLM-L6-v2\n  provider_id:
+    sentence-transformers\n  model_type: embedding\n- metadata: {}\n  model_id: ${env.INFERENCE_MODEL}\n
+    \ provider_id: vllm-inference\n  model_type: llm\n- metadata: {}\n  model_id:
+    ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}\n  provider_id: vllm-safety\n
+    \ model_type: llm\nshields:\n- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}\nvector_dbs:
+    []\ndatasets: []\nscoring_fns: []\nbenchmarks: []\ntool_groups:\n- toolgroup_id:
+    builtin::websearch\n  provider_id: tavily-search\n- toolgroup_id: builtin::rag\n
+    \ provider_id: rag-runtime\nserver:\n  port: 8321\n  auth:\n    provider_config:\n
+    \     type: github_token\n"
 kind: ConfigMap
 metadata:
   creationTimestamp: null
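Because the embedded run config is now stored as a double-quoted flow scalar rather than a block scalar, a quick way to sanity-check the change is to parse the ConfigMap and confirm the decoded config still loads and enables the new files API. A minimal sketch, with a hypothetical manifest filename:

```python
# Minimal sketch: confirm the re-serialized stack_run_config.yaml still parses
# and now declares the files API/provider. The manifest filename is hypothetical.
import yaml

with open("stack-configmap.yaml") as f:
    configmap = yaml.safe_load(f)

run_config = yaml.safe_load(configmap["data"]["stack_run_config.yaml"])
assert "files" in run_config["apis"]
assert any(p["provider_id"] == "meta-reference-files" for p in run_config["providers"]["files"])
```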
@@ -3,6 +3,7 @@ image_name: kubernetes-demo
 apis:
 - agents
 - inference
+- files
 - safety
 - telemetry
 - tool_runtime

@@ -38,6 +39,14 @@ providers:
       db: ${env.POSTGRES_DB:=llamastack}
       user: ${env.POSTGRES_USER:=llamastack}
       password: ${env.POSTGRES_PASSWORD:=llamastack}
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard
@@ -45,6 +45,7 @@ from llama_stack.core.utils.dynamic import instantiate_class_type
 from llama_stack.core.utils.exec import formulate_run_args, run_command
 from llama_stack.core.utils.image_types import LlamaStackImageType
 from llama_stack.providers.datatypes import Api
+from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

 DISTRIBS_PATH = Path(__file__).parent.parent.parent / "distributions"

@@ -294,6 +295,12 @@ def _generate_run_config(
         if build_config.external_providers_dir
         else EXTERNAL_PROVIDERS_DIR,
     )
+    if not run_config.inference_store:
+        run_config.inference_store = SqliteSqlStoreConfig(
+            **SqliteSqlStoreConfig.sample_run_config(
+                __distro_dir__=(DISTRIBS_BASE_DIR / image_name).as_posix(), db_name="inference_store.db"
+            )
+        )
     # build providers dict
     provider_registry = get_provider_registry(build_config)
     for api in apis:
@@ -10,7 +10,6 @@ import json
 import logging  # allow-direct-logging
 import os
 import sys
-from concurrent.futures import ThreadPoolExecutor
 from enum import Enum
 from io import BytesIO
 from pathlib import Path

@@ -148,7 +147,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
         self.async_client = AsyncLlamaStackAsLibraryClient(
             config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
         )
-        self.pool_executor = ThreadPoolExecutor(max_workers=4)
         self.provider_data = provider_data

         self.loop = asyncio.new_event_loop()
@@ -8,7 +8,7 @@
 from jinja2 import Template

 from llama_stack.apis.common.content_types import InterleavedContent
-from llama_stack.apis.inference import UserMessage
+from llama_stack.apis.inference import OpenAIUserMessageParam
 from llama_stack.apis.tools.rag_tool import (
     DefaultRAGQueryGeneratorConfig,
     LLMRAGQueryGeneratorConfig,

@@ -61,16 +61,16 @@ async def llm_rag_query_generator(
     messages = [interleaved_content_as_str(content)]

     template = Template(config.template)
-    content = template.render({"messages": messages})
+    rendered_content: str = template.render({"messages": messages})

     model = config.model
-    message = UserMessage(content=content)
-    response = await inference_api.chat_completion(
-        model_id=model,
+    message = OpenAIUserMessageParam(content=rendered_content)
+    response = await inference_api.openai_chat_completion(
+        model=model,
         messages=[message],
         stream=False,
     )

-    query = response.completion_message.content
+    query = response.choices[0].message.content

     return query
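The response shape changes with this switch: an OpenAI-style chat completion returns a list of choices rather than a single completion message, which is why the generated query is now read from `response.choices[0].message.content`. A small sketch of the new call path, reusing the names from the diff above:

```python
# Sketch of the updated call path; `inference_api`, `model`, and `rendered_content`
# refer to the names used in the diff above.
message = OpenAIUserMessageParam(content=rendered_content)
response = await inference_api.openai_chat_completion(
    model=model,
    messages=[message],
    stream=False,
)
query = response.choices[0].message.content  # previously: response.completion_message.content
```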
@@ -45,10 +45,7 @@ from llama_stack.apis.vector_io import (
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
-from llama_stack.providers.utils.memory.vector_store import (
-    content_from_doc,
-    parse_data_url,
-)
+from llama_stack.providers.utils.memory.vector_store import parse_data_url

 from .config import RagToolRuntimeConfig
 from .context_retriever import generate_rag_query

@@ -60,6 +57,47 @@ def make_random_string(length: int = 8):
     return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length))


+async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
+    """Get raw binary data and mime type from a RAGDocument for file upload."""
+    if isinstance(doc.content, URL):
+        if doc.content.uri.startswith("data:"):
+            parts = parse_data_url(doc.content.uri)
+            mime_type = parts["mimetype"]
+            data = parts["data"]
+
+            if parts["is_base64"]:
+                file_data = base64.b64decode(data)
+            else:
+                file_data = data.encode("utf-8")
+
+            return file_data, mime_type
+        else:
+            async with httpx.AsyncClient() as client:
+                r = await client.get(doc.content.uri)
+                r.raise_for_status()
+                mime_type = r.headers.get("content-type", "application/octet-stream")
+                return r.content, mime_type
+    else:
+        if isinstance(doc.content, str):
+            content_str = doc.content
+        else:
+            content_str = interleaved_content_as_str(doc.content)
+
+        if content_str.startswith("data:"):
+            parts = parse_data_url(content_str)
+            mime_type = parts["mimetype"]
+            data = parts["data"]
+
+            if parts["is_base64"]:
+                file_data = base64.b64decode(data)
+            else:
+                file_data = data.encode("utf-8")
+
+            return file_data, mime_type
+        else:
+            return content_str.encode("utf-8"), "text/plain"
+
+
 class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
     def __init__(
         self,

@@ -95,20 +133,12 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
             return

         for doc in documents:
-            if isinstance(doc.content, URL):
-                if doc.content.uri.startswith("data:"):
-                    parts = parse_data_url(doc.content.uri)
-                    file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode()
-                    mime_type = parts["mimetype"]
-                else:
-                    async with httpx.AsyncClient() as client:
-                        response = await client.get(doc.content.uri)
-                        file_data = response.content
-                        mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream")
-            else:
-                content_str = await content_from_doc(doc)
-                file_data = content_str.encode("utf-8")
-                mime_type = doc.mime_type or "text/plain"
+            try:
+                try:
+                    file_data, mime_type = await raw_data_from_doc(doc)
+                except Exception as e:
+                    log.error(f"Failed to extract content from document {doc.document_id}: {e}")
+                    continue

                 file_extension = mimetypes.guess_extension(mime_type) or ".txt"
                 filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")

@@ -118,9 +148,13 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):

                 upload_file = UploadFile(file=file_obj, filename=filename)

+                try:
                     created_file = await self.files_api.openai_upload_file(
                         file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
                     )
+                except Exception as e:
+                    log.error(f"Failed to upload file for document {doc.document_id}: {e}")
+                    continue

                 chunking_strategy = VectorStoreChunkingStrategyStatic(
                     static=VectorStoreChunkingStrategyStaticConfig(

@@ -129,12 +163,22 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
                     )
                 )

+                try:
                     await self.vector_io_api.openai_attach_file_to_vector_store(
                         vector_store_id=vector_db_id,
                         file_id=created_file.id,
                         attributes=doc.metadata,
                         chunking_strategy=chunking_strategy,
                     )
+                except Exception as e:
+                    log.error(
+                        f"Failed to attach file {created_file.id} to vector store {vector_db_id} for document {doc.document_id}: {e}"
+                    )
+                    continue
+
+            except Exception as e:
+                log.error(f"Unexpected error processing document {doc.document_id}: {e}")
+                continue

     async def query(
         self,

@@ -167,8 +211,18 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
             for vector_db_id in vector_db_ids
         ]
         results: list[QueryChunksResponse] = await asyncio.gather(*tasks)
-        chunks = [c for r in results for c in r.chunks]
-        scores = [s for r in results for s in r.scores]
+
+        chunks = []
+        scores = []
+
+        for vector_db_id, result in zip(vector_db_ids, results, strict=False):
+            for chunk, score in zip(result.chunks, result.scores, strict=False):
+                if not hasattr(chunk, "metadata") or chunk.metadata is None:
+                    chunk.metadata = {}
+                chunk.metadata["vector_db_id"] = vector_db_id
+
+                chunks.append(chunk)
+                scores.append(score)

         if not chunks:
             return RAGQueryResult(content=None)

@@ -203,6 +257,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         metadata_keys_to_exclude_from_context = [
             "token_count",
             "metadata_token_count",
+            "vector_db_id",
         ]
         metadata_for_context = {}
         for k in chunk_metadata_keys_to_include_from_context:

@@ -227,6 +282,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
                 "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]],
                 "chunks": [c.content for c in chunks[: len(picked)]],
                 "scores": scores[: len(picked)],
+                "vector_db_ids": [c.metadata["vector_db_id"] for c in chunks[: len(picked)]],
             },
         )

@@ -262,7 +318,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         if query_config:
             query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
         else:
-            # handle someone passing an empty dict
             query_config = RAGQueryConfig()

         query = kwargs["query"]

@@ -273,6 +328,6 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
         )

         return ToolInvocationResult(
-            content=result.content,
+            content=result.content or [],
             metadata=result.metadata,
         )
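With chunks now annotated per source store, callers of the RAG tool can tell which vector DB each result came from. A hedged sketch of reading that metadata from a query result is shown below; `rag_tool` stands in for a RAGToolRuntime API handle and the store names are placeholders, while the metadata keys (`document_ids`, `vector_db_ids`, `scores`) follow the diff above.

```python
# Illustrative sketch: per-chunk vector store IDs in the query result metadata.
# `rag_tool` is a placeholder for the RAGToolRuntime API handle.
result = await rag_tool.query(content="What is Llama Stack?", vector_db_ids=["docs", "faqs"])

for doc_id, store_id, score in zip(
    result.metadata["document_ids"],
    result.metadata["vector_db_ids"],
    result.metadata["scores"],
):
    print(f"{doc_id} (from {store_id}): score={score:.3f}")
```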
@@ -218,7 +218,7 @@ def available_providers() -> list[ProviderSpec]:
         api=Api.inference,
         adapter=AdapterSpec(
             adapter_type="vertexai",
-            pip_packages=["litellm", "google-cloud-aiplatform"],
+            pip_packages=["litellm", "google-cloud-aiplatform", "openai"],
             module="llama_stack.providers.remote.inference.vertexai",
             config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
             provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
@@ -6,16 +6,20 @@

 from typing import Any

+import google.auth.transport.requests
+from google.auth import default
+
 from llama_stack.apis.inference import ChatCompletionRequest
 from llama_stack.providers.utils.inference.litellm_openai_mixin import (
     LiteLLMOpenAIMixin,
 )
+from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin

 from .config import VertexAIConfig
 from .models import MODEL_ENTRIES


-class VertexAIInferenceAdapter(LiteLLMOpenAIMixin):
+class VertexAIInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
     def __init__(self, config: VertexAIConfig) -> None:
         LiteLLMOpenAIMixin.__init__(
             self,

@@ -27,10 +31,31 @@ class VertexAIInferenceAdapter(LiteLLMOpenAIMixin):
         self.config = config

     def get_api_key(self) -> str:
-        # Vertex AI doesn't use API keys, it uses Application Default Credentials
-        # Return empty string to let litellm handle authentication via ADC
-        return ""
+        """
+        Get an access token for Vertex AI using Application Default Credentials.
+
+        Vertex AI uses ADC instead of API keys. This method obtains an access token
+        from the default credentials and returns it for use with the OpenAI-compatible client.
+        """
+        try:
+            # Get default credentials - will read from GOOGLE_APPLICATION_CREDENTIALS
+            credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
+            credentials.refresh(google.auth.transport.requests.Request())
+            return str(credentials.token)
+        except Exception:
+            # If we can't get credentials, return empty string to let LiteLLM handle it
+            # This allows the LiteLLM mixin to work with ADC directly
+            return ""
+
+    def get_base_url(self) -> str:
+        """
+        Get the Vertex AI OpenAI-compatible API base URL.
+
+        Returns the Vertex AI OpenAI-compatible endpoint URL.
+        Source: https://cloud.google.com/vertex-ai/generative-ai/docs/start/openai
+        """
+        return f"https://{self.config.location}-aiplatform.googleapis.com/v1/projects/{self.config.project}/locations/{self.config.location}/endpoints/openapi"

     async def _get_params(self, request: ChatCompletionRequest) -> dict[str, Any]:
         # Get base parameters from parent
         params = await super()._get_params(request)
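Outside of the adapter, the same two pieces — an ADC access token and the region/project-scoped OpenAI-compatible endpoint — are enough to talk to Vertex AI with a plain OpenAI client. A rough sketch follows; the project, location, and model name are placeholders and not taken from this commit.

```python
# Rough sketch (not part of this commit): ADC token + the Vertex AI OpenAI-compatible
# endpoint used directly with the openai SDK. PROJECT, LOCATION, and the model name
# are placeholders; see the Vertex AI OpenAI-compatibility docs for exact model IDs.
import google.auth
import google.auth.transport.requests
from openai import OpenAI

PROJECT, LOCATION = "my-gcp-project", "us-central1"

creds, _ = google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
creds.refresh(google.auth.transport.requests.Request())

client = OpenAI(
    api_key=creds.token,
    base_url=f"https://{LOCATION}-aiplatform.googleapis.com/v1/projects/{PROJECT}/locations/{LOCATION}/endpoints/openapi",
)
resp = client.chat.completions.create(
    model="google/gemini-2.0-flash",  # assumed model name for illustration
    messages=[{"role": "user", "content": "Hello from Vertex AI"}],
)
print(resp.choices[0].message.content)
```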
scripts/get_setup_env.py (new executable file) — 71 lines

@@ -0,0 +1,71 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
+
+"""
+Small helper script to extract environment variables from a test setup.
+Used by integration-tests.sh to set environment variables before starting the server.
+"""
+
+import argparse
+import sys
+
+from tests.integration.suites import SETUP_DEFINITIONS, SUITE_DEFINITIONS
+
+
+def get_setup_env_vars(setup_name, suite_name=None):
+    """
+    Get environment variables for a setup, with optional suite default fallback.
+
+    Args:
+        setup_name: Name of the setup (e.g., 'ollama', 'gpt')
+        suite_name: Optional suite name to get default setup if setup_name is None
+
+    Returns:
+        Dictionary of environment variables
+    """
+    # If no setup specified, try to get default from suite
+    if not setup_name and suite_name:
+        suite = SUITE_DEFINITIONS.get(suite_name)
+        if suite and suite.default_setup:
+            setup_name = suite.default_setup
+
+    if not setup_name:
+        return {}
+
+    setup = SETUP_DEFINITIONS.get(setup_name)
+    if not setup:
+        print(
+            f"Error: Unknown setup '{setup_name}'. Available: {', '.join(sorted(SETUP_DEFINITIONS.keys()))}",
+            file=sys.stderr,
+        )
+        sys.exit(1)
+
+    return setup.env
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Extract environment variables from a test setup")
+    parser.add_argument("--setup", help="Setup name (e.g., ollama, gpt)")
+    parser.add_argument("--suite", help="Suite name to get default setup from if --setup not provided")
+    parser.add_argument("--format", choices=["bash", "json"], default="bash", help="Output format (default: bash)")
+
+    args = parser.parse_args()
+
+    env_vars = get_setup_env_vars(args.setup, args.suite)
+
+    if args.format == "bash":
+        # Output as bash export statements
+        for key, value in env_vars.items():
+            print(f"export {key}='{value}'")
+    elif args.format == "json":
+        import json
+
+        print(json.dumps(env_vars))
+
+
+if __name__ == "__main__":
+    main()
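A typical way to consume this helper from other tooling is to ask for JSON output and merge the result into the current environment before launching the server. A small sketch, assuming the script is invoked from the repository root:

```python
# Small sketch: run the helper and merge its output into os.environ.
# Arguments mirror the script's own CLI (--setup / --format json).
import json
import os
import subprocess

out = subprocess.run(
    ["python", "scripts/get_setup_env.py", "--setup", "ollama", "--format", "json"],
    capture_output=True, text=True, check=True,
).stdout
os.environ.update(json.loads(out))
```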
@@ -14,7 +14,7 @@ set -euo pipefail
 # Default values
 BRANCH=""
 TEST_SUBDIRS=""
-TEST_PROVIDER="ollama"
+TEST_SETUP="ollama"
 TEST_SUITE="base"
 TEST_PATTERN=""

@@ -27,24 +27,24 @@ Trigger the integration test recording workflow remotely. This way you do not ne

 OPTIONS:
     -b, --branch BRANCH            Branch to run the workflow on (defaults to current branch)
-    -p, --test-provider PROVIDER   Test provider to use: vllm or ollama (default: ollama)
-    -t, --test-suite SUITE         Test suite to use: base, responses, vision, etc. (default: base)
-    -s, --test-subdirs DIRS        Comma-separated list of test subdirectories to run (overrides suite)
-    -k, --test-pattern PATTERN     Regex pattern to pass to pytest -k
+    -t, --suite SUITE              Test suite to use: base, responses, vision, etc. (default: base)
+    -p, --setup SETUP              Test setup to use: vllm, ollama, gpt, etc. (default: ollama)
+    -s, --subdirs DIRS             Comma-separated list of test subdirectories to run (overrides suite)
+    -k, --pattern PATTERN          Regex pattern to pass to pytest -k
     -h, --help                     Show this help message

 EXAMPLES:
     # Record tests for current branch with agents subdirectory
-    $0 --test-subdirs "agents"
+    $0 --subdirs "agents"

     # Record tests for specific branch with vision tests
-    $0 -b my-feature-branch --test-suite vision
+    $0 -b my-feature-branch --suite vision

-    # Record multiple test subdirectories with specific provider
-    $0 --test-subdirs "agents,inference" --test-provider vllm
+    # Record multiple test subdirectories with specific setup
+    $0 --subdirs "agents,inference" --setup vllm

     # Record tests matching a specific pattern
-    $0 --test-subdirs "inference" --test-pattern "test_streaming"
+    $0 --subdirs "inference" --pattern "test_streaming"

 EOF
 }

@@ -63,19 +63,19 @@ while [[ $# -gt 0 ]]; do
             BRANCH="$2"
             shift 2
             ;;
-        -s|--test-subdirs)
+        -s|--subdirs)
             TEST_SUBDIRS="$2"
             shift 2
             ;;
-        -p|--test-provider)
-            TEST_PROVIDER="$2"
+        -p|--setup)
+            TEST_SETUP="$2"
             shift 2
             ;;
-        -t|--test-suite)
+        -t|--suite)
             TEST_SUITE="$2"
             shift 2
             ;;
-        -k|--test-pattern)
+        -k|--pattern)
             TEST_PATTERN="$2"
             shift 2
             ;;

@@ -93,21 +93,16 @@ done

 # Validate required parameters
 if [[ -z "$TEST_SUBDIRS" && -z "$TEST_SUITE" ]]; then
-    echo "Error: --test-subdirs or --test-suite is required"
+    echo "Error: --subdirs or --suite is required"
     echo "Please specify which test subdirectories to run or test suite to use, e.g.:"
-    echo "  $0 --test-subdirs \"agents,inference\""
-    echo "  $0 --test-suite vision"
+    echo "  $0 --subdirs \"agents,inference\""
+    echo "  $0 --suite vision"
     echo ""
     exit 1
 fi

-# Validate test provider
-if [[ "$TEST_PROVIDER" != "vllm" && "$TEST_PROVIDER" != "ollama" ]]; then
-    echo "❌ Error: Invalid test provider '$TEST_PROVIDER'"
-    echo "   Supported providers: vllm, ollama"
-    echo "   Example: $0 --test-subdirs \"agents\" --test-provider vllm"
-    exit 1
-fi
+# Validate test setup (optional - setups are validated by the workflow itself)
+# Common setups: ollama, vllm, gpt, etc.

 # Check if required tools are installed
 if ! command -v gh &> /dev/null; then

@@ -237,7 +232,7 @@ fi
 # Build the workflow dispatch command
 echo "Triggering integration test recording workflow..."
 echo "Branch: $BRANCH"
-echo "Test provider: $TEST_PROVIDER"
+echo "Test setup: $TEST_SETUP"
 echo "Test subdirs: $TEST_SUBDIRS"
 echo "Test suite: $TEST_SUITE"
 echo "Test pattern: ${TEST_PATTERN:-"(none)"}"

@@ -245,16 +240,16 @@ echo ""

 # Prepare inputs for gh workflow run
 if [[ -n "$TEST_SUBDIRS" ]]; then
-    INPUTS="-f test-subdirs='$TEST_SUBDIRS'"
+    INPUTS="-f subdirs='$TEST_SUBDIRS'"
 fi
-if [[ -n "$TEST_PROVIDER" ]]; then
-    INPUTS="$INPUTS -f test-provider='$TEST_PROVIDER'"
+if [[ -n "$TEST_SETUP" ]]; then
+    INPUTS="$INPUTS -f test-setup='$TEST_SETUP'"
 fi
 if [[ -n "$TEST_SUITE" ]]; then
-    INPUTS="$INPUTS -f test-suite='$TEST_SUITE'"
+    INPUTS="$INPUTS -f suite='$TEST_SUITE'"
 fi
 if [[ -n "$TEST_PATTERN" ]]; then
-    INPUTS="$INPUTS -f test-pattern='$TEST_PATTERN'"
+    INPUTS="$INPUTS -f pattern='$TEST_PATTERN'"
 fi

 # Run the workflow
@ -13,10 +13,10 @@ set -euo pipefail
|
||||||
|
|
||||||
# Default values
|
# Default values
|
||||||
STACK_CONFIG=""
|
STACK_CONFIG=""
|
||||||
PROVIDER=""
|
TEST_SUITE="base"
|
||||||
|
TEST_SETUP=""
|
||||||
TEST_SUBDIRS=""
|
TEST_SUBDIRS=""
|
||||||
TEST_PATTERN=""
|
TEST_PATTERN=""
|
||||||
TEST_SUITE="base"
|
|
||||||
INFERENCE_MODE="replay"
|
INFERENCE_MODE="replay"
|
||||||
EXTRA_PARAMS=""
|
EXTRA_PARAMS=""
|
||||||
|
|
||||||
|
@ -27,29 +27,30 @@ Usage: $0 [OPTIONS]
|
||||||
|
|
||||||
Options:
|
Options:
|
||||||
--stack-config STRING Stack configuration to use (required)
|
--stack-config STRING Stack configuration to use (required)
|
||||||
--provider STRING Provider to use (ollama, vllm, etc.) (required)
|
--suite STRING Test suite to run (default: 'base')
|
||||||
--test-suite STRING Comma-separated list of test suites to run (default: 'base')
|
--setup STRING Test setup (models, env) to use (e.g., 'ollama', 'ollama-vision', 'gpt', 'vllm')
|
||||||
--inference-mode STRING Inference mode: record or replay (default: replay)
|
--inference-mode STRING Inference mode: record or replay (default: replay)
|
||||||
--test-subdirs STRING Comma-separated list of test subdirectories to run (overrides suite)
|
--subdirs STRING Comma-separated list of test subdirectories to run (overrides suite)
|
||||||
--test-pattern STRING Regex pattern to pass to pytest -k
|
--pattern STRING Regex pattern to pass to pytest -k
|
||||||
--help Show this help message
|
--help Show this help message
|
||||||
|
|
||||||
Suites are defined in tests/integration/suites.py. They are used to narrow the collection of tests and provide default model options.
|
Suites are defined in tests/integration/suites.py and define which tests to run.
|
||||||
|
Setups are defined in tests/integration/suites.py and provide global configuration (models, env).
|
||||||
|
|
||||||
You can also specify subdirectories (of tests/integration) to select tests from, which will override the suite.
|
You can also specify subdirectories (of tests/integration) to select tests from, which will override the suite.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
# Basic inference tests with ollama
|
# Basic inference tests with ollama
|
||||||
$0 --stack-config server:ci-tests --provider ollama
|
$0 --stack-config server:ci-tests --suite base --setup ollama
|
||||||
|
|
||||||
# Multiple test directories with vllm
|
# Multiple test directories with vllm
|
||||||
$0 --stack-config server:ci-tests --provider vllm --test-subdirs 'inference,agents'
|
$0 --stack-config server:ci-tests --subdirs 'inference,agents' --setup vllm
|
||||||
|
|
||||||
# Vision tests with ollama
|
# Vision tests with ollama
|
||||||
$0 --stack-config server:ci-tests --provider ollama --test-suite vision
|
$0 --stack-config server:ci-tests --suite vision # default setup for this suite is ollama-vision
|
||||||
|
|
||||||
# Record mode for updating test recordings
|
# Record mode for updating test recordings
|
||||||
$0 --stack-config server:ci-tests --provider ollama --inference-mode record
|
$0 --stack-config server:ci-tests --suite base --inference-mode record
|
||||||
EOF
|
EOF
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -60,15 +61,15 @@ while [[ $# -gt 0 ]]; do
|
||||||
STACK_CONFIG="$2"
|
STACK_CONFIG="$2"
|
||||||
shift 2
|
shift 2
|
||||||
;;
|
;;
|
||||||
--provider)
|
--setup)
|
||||||
PROVIDER="$2"
|
TEST_SETUP="$2"
|
||||||
shift 2
|
shift 2
|
||||||
;;
|
;;
|
||||||
--test-subdirs)
|
--subdirs)
|
||||||
TEST_SUBDIRS="$2"
|
TEST_SUBDIRS="$2"
|
||||||
shift 2
|
shift 2
|
||||||
;;
|
;;
|
||||||
--test-suite)
|
--suite)
|
||||||
TEST_SUITE="$2"
|
TEST_SUITE="$2"
|
||||||
shift 2
|
shift 2
|
||||||
;;
|
;;
|
||||||
|
@ -76,7 +77,7 @@ while [[ $# -gt 0 ]]; do
|
||||||
INFERENCE_MODE="$2"
|
INFERENCE_MODE="$2"
|
||||||
shift 2
|
shift 2
|
||||||
;;
|
;;
|
||||||
--test-pattern)
|
--pattern)
|
||||||
TEST_PATTERN="$2"
|
TEST_PATTERN="$2"
|
||||||
shift 2
|
shift 2
|
||||||
;;
|
;;
|
||||||
|
@ -96,11 +97,13 @@ done
|
||||||
# Validate required parameters
|
# Validate required parameters
|
||||||
if [[ -z "$STACK_CONFIG" ]]; then
|
if [[ -z "$STACK_CONFIG" ]]; then
|
||||||
echo "Error: --stack-config is required"
|
echo "Error: --stack-config is required"
|
||||||
|
usage
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [[ -z "$PROVIDER" ]]; then
|
if [[ -z "$TEST_SETUP" && -n "$TEST_SUBDIRS" ]]; then
|
||||||
echo "Error: --provider is required"
|
echo "Error: --test-setup is required when --test-subdirs is provided"
|
||||||
|
usage
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
@ -111,7 +114,7 @@ fi
|
||||||
|
|
||||||
echo "=== Llama Stack Integration Test Runner ==="
|
echo "=== Llama Stack Integration Test Runner ==="
|
||||||
echo "Stack Config: $STACK_CONFIG"
|
echo "Stack Config: $STACK_CONFIG"
|
||||||
echo "Provider: $PROVIDER"
|
echo "Setup: $TEST_SETUP"
|
||||||
echo "Inference Mode: $INFERENCE_MODE"
|
echo "Inference Mode: $INFERENCE_MODE"
|
||||||
echo "Test Suite: $TEST_SUITE"
|
echo "Test Suite: $TEST_SUITE"
|
||||||
echo "Test Subdirs: $TEST_SUBDIRS"
|
echo "Test Subdirs: $TEST_SUBDIRS"
|
||||||
|
@ -129,21 +132,25 @@ echo ""
|
||||||
|
|
||||||
# Set environment variables
|
# Set environment variables
|
||||||
export LLAMA_STACK_CLIENT_TIMEOUT=300
|
export LLAMA_STACK_CLIENT_TIMEOUT=300
|
||||||
export LLAMA_STACK_TEST_INFERENCE_MODE="$INFERENCE_MODE"
|
|
||||||
|
|
||||||
# Configure provider-specific settings
|
|
||||||
if [[ "$PROVIDER" == "ollama" ]]; then
|
|
||||||
export OLLAMA_URL="http://0.0.0.0:11434"
|
|
||||||
export TEXT_MODEL="ollama/llama3.2:3b-instruct-fp16"
|
|
||||||
export SAFETY_MODEL="ollama/llama-guard3:1b"
|
|
||||||
EXTRA_PARAMS="--safety-shield=llama-guard"
|
|
||||||
else
|
|
||||||
export VLLM_URL="http://localhost:8000/v1"
|
|
||||||
export TEXT_MODEL="vllm/meta-llama/Llama-3.2-1B-Instruct"
|
|
||||||
EXTRA_PARAMS=""
|
|
||||||
fi
|
|
||||||
|
|
||||||
THIS_DIR=$(dirname "$0")
|
THIS_DIR=$(dirname "$0")
|
||||||
|
|
||||||
|
if [[ -n "$TEST_SETUP" ]]; then
|
||||||
|
EXTRA_PARAMS="--setup=$TEST_SETUP"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Apply setup-specific environment variables (needed for server startup and tests)
|
||||||
|
echo "=== Applying Setup Environment Variables ==="
|
||||||
|
|
||||||
|
# the server needs this
|
||||||
|
export LLAMA_STACK_TEST_INFERENCE_MODE="$INFERENCE_MODE"
|
||||||
|
|
||||||
|
SETUP_ENV=$(PYTHONPATH=$THIS_DIR/.. python "$THIS_DIR/get_setup_env.py" --suite "$TEST_SUITE" --setup "$TEST_SETUP" --format bash)
|
||||||
|
echo "Setting up environment variables:"
|
||||||
|
echo "$SETUP_ENV"
|
||||||
|
eval "$SETUP_ENV"
|
||||||
|
echo ""
|
||||||
|
|
||||||
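The `get_setup_env.py` helper invoked above is not shown in this diff; as a hedged sketch of what it presumably does (resolve the setup, falling back to the suite's default, and emit the setup's env vars as bash `export` lines for the `eval` above), it might look roughly like the following — the file path, import path, and argument names are assumptions for illustration:

```python
# Hypothetical sketch of scripts/get_setup_env.py (the real file is not part of this diff).
import argparse

from tests.integration.suites import SETUP_DEFINITIONS, SUITE_DEFINITIONS


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--suite", default="")
    parser.add_argument("--setup", default="")
    # --format is accepted for parity with the shell invocation; only bash output is sketched here.
    parser.add_argument("--format", choices=["bash"], default="bash")
    args = parser.parse_args()

    # Fall back to the suite's default setup when --setup is not given.
    setup = args.setup
    if not setup and args.suite:
        setup = SUITE_DEFINITIONS[args.suite].default_setup
    if not setup:
        return  # nothing to export

    # Print env vars as bash export statements so the caller can eval them.
    for key, value in SETUP_DEFINITIONS[setup].env.items():
        print(f"export {key}='{value}'")


if __name__ == "__main__":
    main()
```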
ROOT_DIR="$THIS_DIR/.."
|
ROOT_DIR="$THIS_DIR/.."
|
||||||
cd $ROOT_DIR
|
cd $ROOT_DIR
|
||||||
|
|
||||||
|
@ -162,6 +169,18 @@ fi
|
||||||
|
|
||||||
# Start Llama Stack Server if needed
|
# Start Llama Stack Server if needed
|
||||||
if [[ "$STACK_CONFIG" == *"server:"* ]]; then
|
if [[ "$STACK_CONFIG" == *"server:"* ]]; then
|
||||||
|
stop_server() {
|
||||||
|
echo "Stopping Llama Stack Server..."
|
||||||
|
pids=$(lsof -i :8321 | awk 'NR>1 {print $2}')
|
||||||
|
if [[ -n "$pids" ]]; then
|
||||||
|
echo "Killing Llama Stack Server processes: $pids"
|
||||||
|
kill -9 $pids
|
||||||
|
else
|
||||||
|
echo "No Llama Stack Server processes found ?!"
|
||||||
|
fi
|
||||||
|
echo "Llama Stack Server stopped"
|
||||||
|
}
|
||||||
|
|
||||||
# check if server is already running
|
# check if server is already running
|
||||||
if curl -s http://localhost:8321/v1/health 2>/dev/null | grep -q "OK"; then
|
if curl -s http://localhost:8321/v1/health 2>/dev/null | grep -q "OK"; then
|
||||||
echo "Llama Stack Server is already running, skipping start"
|
echo "Llama Stack Server is already running, skipping start"
|
||||||
|
@ -185,14 +204,16 @@ if [[ "$STACK_CONFIG" == *"server:"* ]]; then
|
||||||
done
|
done
|
||||||
echo ""
|
echo ""
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
trap stop_server EXIT ERR INT TERM
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Run tests
|
# Run tests
|
||||||
echo "=== Running Integration Tests ==="
|
echo "=== Running Integration Tests ==="
|
||||||
EXCLUDE_TESTS="builtin_tool or safety_with_image or code_interpreter or test_rag"
|
EXCLUDE_TESTS="builtin_tool or safety_with_image or code_interpreter or test_rag"
|
||||||
|
|
||||||
# Additional exclusions for vllm provider
|
# Additional exclusions for vllm setup
|
||||||
if [[ "$PROVIDER" == "vllm" ]]; then
|
if [[ "$TEST_SETUP" == "vllm" ]]; then
|
||||||
EXCLUDE_TESTS="${EXCLUDE_TESTS} or test_inference_store_tool_calls"
|
EXCLUDE_TESTS="${EXCLUDE_TESTS} or test_inference_store_tool_calls"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
@ -229,20 +250,22 @@ if [[ -n "$TEST_SUBDIRS" ]]; then
|
||||||
echo "Total test files: $(echo $TEST_FILES | wc -w)"
|
echo "Total test files: $(echo $TEST_FILES | wc -w)"
|
||||||
|
|
||||||
PYTEST_TARGET="$TEST_FILES"
|
PYTEST_TARGET="$TEST_FILES"
|
||||||
EXTRA_PARAMS="$EXTRA_PARAMS --text-model=$TEXT_MODEL --embedding-model=sentence-transformers/all-MiniLM-L6-v2"
|
|
||||||
else
|
else
|
||||||
PYTEST_TARGET="tests/integration/"
|
PYTEST_TARGET="tests/integration/"
|
||||||
EXTRA_PARAMS="$EXTRA_PARAMS --suite=$TEST_SUITE"
|
EXTRA_PARAMS="$EXTRA_PARAMS --suite=$TEST_SUITE"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
set +e
|
set +e
|
||||||
|
set -x
|
||||||
pytest -s -v $PYTEST_TARGET \
|
pytest -s -v $PYTEST_TARGET \
|
||||||
--stack-config="$STACK_CONFIG" \
|
--stack-config="$STACK_CONFIG" \
|
||||||
|
--inference-mode="$INFERENCE_MODE" \
|
||||||
-k "$PYTEST_PATTERN" \
|
-k "$PYTEST_PATTERN" \
|
||||||
$EXTRA_PARAMS \
|
$EXTRA_PARAMS \
|
||||||
--color=yes \
|
--color=yes \
|
||||||
--capture=tee-sys
|
--capture=tee-sys
|
||||||
exit_code=$?
|
exit_code=$?
|
||||||
|
set +x
|
||||||
set -e
|
set -e
|
||||||
|
|
||||||
if [ $exit_code -eq 0 ]; then
|
if [ $exit_code -eq 0 ]; then
|
||||||
|
@ -260,18 +283,5 @@ echo "=== System Resources After Tests ==="
|
||||||
free -h 2>/dev/null || echo "free command not available"
|
free -h 2>/dev/null || echo "free command not available"
|
||||||
df -h
|
df -h
|
||||||
|
|
||||||
# stop server
|
|
||||||
if [[ "$STACK_CONFIG" == *"server:"* ]]; then
|
|
||||||
echo "Stopping Llama Stack Server..."
|
|
||||||
pids=$(lsof -i :8321 | awk 'NR>1 {print $2}')
|
|
||||||
if [[ -n "$pids" ]]; then
|
|
||||||
echo "Killing Llama Stack Server processes: $pids"
|
|
||||||
kill -9 $pids
|
|
||||||
else
|
|
||||||
echo "No Llama Stack Server processes found ?!"
|
|
||||||
fi
|
|
||||||
echo "Llama Stack Server stopped"
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== Integration Tests Complete ==="
|
echo "=== Integration Tests Complete ==="
|
||||||
|
|
|
@ -6,9 +6,7 @@ Integration tests verify complete workflows across different providers using Lla
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Run all integration tests with existing recordings
|
# Run all integration tests with existing recordings
|
||||||
LLAMA_STACK_TEST_INFERENCE_MODE=replay \
|
uv run --group test \
|
||||||
LLAMA_STACK_TEST_RECORDING_DIR=tests/integration/recordings \
|
|
||||||
uv run --group test \
|
|
||||||
pytest -sv tests/integration/ --stack-config=starter
|
pytest -sv tests/integration/ --stack-config=starter
|
||||||
```
|
```
|
||||||
|
|
||||||
|
@ -42,25 +40,35 @@ Model parameters can be influenced by the following options:
|
||||||
Each of these are comma-separated lists and can be used to generate multiple parameter combinations. Note that tests will be skipped
|
Each of these are comma-separated lists and can be used to generate multiple parameter combinations. Note that tests will be skipped
|
||||||
if no model is specified.
|
if no model is specified.
|
||||||
|
|
||||||
### Suites (fast selection + sane defaults)
|
### Suites and Setups
|
||||||
|
|
||||||
- `--suite`: comma-separated list of named suites that both narrow which tests are collected and prefill common model options (unless you pass them explicitly).
|
- `--suite`: single named suite that narrows which tests are collected.
|
||||||
- Available suites:
|
- Available suites:
|
||||||
- `responses`: collects tests under `tests/integration/responses`; this is a separate suite because it needs a strong tool-calling model.
|
- `base`: collects most tests (excludes responses and post_training)
|
||||||
- `vision`: collects only `tests/integration/inference/test_vision_inference.py`; defaults `--vision-model=ollama/llama3.2-vision:11b`, `--embedding-model=sentence-transformers/all-MiniLM-L6-v2`.
|
- `responses`: collects tests under `tests/integration/responses` (needs strong tool-calling models)
|
||||||
- Explicit flags always win. For example, `--suite=responses --text-model=<X>` overrides the suite’s text model.
|
- `vision`: collects only `tests/integration/inference/test_vision_inference.py`
|
||||||
|
- `--setup`: global configuration that can be used with any suite. Setups prefill model/env defaults; explicit CLI flags always win.
|
||||||
|
- Available setups:
|
||||||
|
- `ollama`: Local Ollama provider with lightweight models (sets OLLAMA_URL, uses llama3.2:3b-instruct-fp16)
|
||||||
|
- `vllm`: VLLM provider for efficient local inference (sets VLLM_URL, uses Llama-3.2-1B-Instruct)
|
||||||
|
- `gpt`: OpenAI GPT models for high-quality responses (uses gpt-4o)
|
||||||
|
- `claude`: Anthropic Claude models for high-quality responses (uses claude-3-5-sonnet)
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Fast responses run with defaults
|
# Fast responses run with a strong tool-calling model
|
||||||
pytest -s -v tests/integration --stack-config=server:starter --suite=responses
|
pytest -s -v tests/integration --stack-config=server:starter --suite=responses --setup=gpt
|
||||||
|
|
||||||
# Fast single-file vision run with defaults
|
# Fast single-file vision run with Ollama defaults
|
||||||
pytest -s -v tests/integration --stack-config=server:starter --suite=vision
|
pytest -s -v tests/integration --stack-config=server:starter --suite=vision --setup=ollama
|
||||||
|
|
||||||
# Combine suites and override a default
|
# Base suite with VLLM for performance
|
||||||
pytest -s -v tests/integration --stack-config=server:starter --suite=responses,vision --embedding-model=text-embedding-3-small
|
pytest -s -v tests/integration --stack-config=server:starter --suite=base --setup=vllm
|
||||||
|
|
||||||
|
# Override a default from setup
|
||||||
|
pytest -s -v tests/integration --stack-config=server:starter \
|
||||||
|
--suite=responses --setup=gpt --embedding-model=text-embedding-3-small
|
||||||
```
|
```
|
||||||
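Note that the `claude` setup listed above does not appear among the setup definitions visible later in this diff; as a hedged sketch, such an entry could be added to `SETUP_DEFINITIONS` in `tests/integration/suites.py` along these lines (the model id is an assumption for illustration):

```python
# Hypothetical "claude" setup, following the Setup model defined in tests/integration/suites.py.
from tests.integration.suites import SETUP_DEFINITIONS, Setup

SETUP_DEFINITIONS["claude"] = Setup(
    name="claude",
    description="Anthropic Claude models for high-quality responses and tool calling",
    defaults={
        # Explicit CLI flags (e.g. --text-model) still override these defaults.
        "text_model": "anthropic/claude-3-5-sonnet-latest",  # assumed model id
        "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
    },
)
```

With an entry like that in place, `--setup=claude` would prefill the text model the same way the built-in setups do.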
|
|
||||||
## Examples
|
## Examples
|
||||||
|
@ -127,14 +135,13 @@ pytest tests/integration/
|
||||||
### RECORD Mode
|
### RECORD Mode
|
||||||
Captures API interactions for later replay:
|
Captures API interactions for later replay:
|
||||||
```bash
|
```bash
|
||||||
LLAMA_STACK_TEST_INFERENCE_MODE=record \
|
pytest tests/integration/inference/test_new_feature.py --inference-mode=record
|
||||||
pytest tests/integration/inference/test_new_feature.py
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### LIVE Mode
|
### LIVE Mode
|
||||||
Tests make real API calls (but not recorded):
|
Tests make real API calls (but not recorded):
|
||||||
```bash
|
```bash
|
||||||
LLAMA_STACK_TEST_INFERENCE_MODE=live pytest tests/integration/
|
pytest tests/integration/ --inference-mode=live
|
||||||
```
|
```
|
||||||
|
|
||||||
By default, the recording directory is `tests/integration/recordings`. You can override this by setting the `LLAMA_STACK_TEST_RECORDING_DIR` environment variable.
|
By default, the recording directory is `tests/integration/recordings`. You can override this by setting the `LLAMA_STACK_TEST_RECORDING_DIR` environment variable.
|
||||||
|
@ -155,15 +162,14 @@ cat recordings/responses/abc123.json | jq '.'
|
||||||
#### Remote Re-recording (Recommended)
|
#### Remote Re-recording (Recommended)
|
||||||
Use the automated workflow script for easier re-recording:
|
Use the automated workflow script for easier re-recording:
|
||||||
```bash
|
```bash
|
||||||
./scripts/github/schedule-record-workflow.sh --test-subdirs "inference,agents"
|
./scripts/github/schedule-record-workflow.sh --subdirs "inference,agents"
|
||||||
```
|
```
|
||||||
See the [main testing guide](../README.md#remote-re-recording-recommended) for full details.
|
See the [main testing guide](../README.md#remote-re-recording-recommended) for full details.
|
||||||
|
|
||||||
#### Local Re-recording
|
#### Local Re-recording
|
||||||
```bash
|
```bash
|
||||||
# Re-record specific tests
|
# Re-record specific tests
|
||||||
LLAMA_STACK_TEST_INFERENCE_MODE=record \
|
pytest -s -v --stack-config=server:starter tests/integration/inference/test_modified.py --inference-mode=record
|
||||||
pytest -s -v --stack-config=server:starter tests/integration/inference/test_modified.py
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Note that when re-recording tests, you must use a Stack pointing to a server (i.e., `server:starter`). This subtlety exists because the set of tests run in server are a superset of the set of tests run in the library client.
|
Note that when re-recording tests, you must use a Stack pointing to a server (i.e., `server:starter`). This subtlety exists because the set of tests run in server are a superset of the set of tests run in the library client.
|
||||||
|
|
|
@ -15,7 +15,7 @@ from dotenv import load_dotenv
|
||||||
|
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
from .suites import SUITE_DEFINITIONS
|
from .suites import SETUP_DEFINITIONS, SUITE_DEFINITIONS
|
||||||
|
|
||||||
logger = get_logger(__name__, category="tests")
|
logger = get_logger(__name__, category="tests")
|
||||||
|
|
||||||
|
@ -63,19 +63,33 @@ def pytest_configure(config):
|
||||||
key, value = env_var.split("=", 1)
|
key, value = env_var.split("=", 1)
|
||||||
os.environ[key] = value
|
os.environ[key] = value
|
||||||
|
|
||||||
suites_raw = config.getoption("--suite")
|
inference_mode = config.getoption("--inference-mode")
|
||||||
suites: list[str] = []
|
os.environ["LLAMA_STACK_TEST_INFERENCE_MODE"] = inference_mode
|
||||||
if suites_raw:
|
|
||||||
suites = [p.strip() for p in str(suites_raw).split(",") if p.strip()]
|
suite = config.getoption("--suite")
|
||||||
unknown = [p for p in suites if p not in SUITE_DEFINITIONS]
|
if suite:
|
||||||
if unknown:
|
if suite not in SUITE_DEFINITIONS:
|
||||||
|
raise pytest.UsageError(f"Unknown suite: {suite}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}")
|
||||||
|
|
||||||
|
# Apply setups (global parameterizations): env + defaults
|
||||||
|
setup = config.getoption("--setup")
|
||||||
|
if suite and not setup:
|
||||||
|
setup = SUITE_DEFINITIONS[suite].default_setup
|
||||||
|
|
||||||
|
if setup:
|
||||||
|
if setup not in SETUP_DEFINITIONS:
|
||||||
raise pytest.UsageError(
|
raise pytest.UsageError(
|
||||||
f"Unknown suite(s): {', '.join(unknown)}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}"
|
f"Unknown setup '{setup}'. Available: {', '.join(sorted(SETUP_DEFINITIONS.keys()))}"
|
||||||
)
|
)
|
||||||
for suite in suites:
|
|
||||||
suite_def = SUITE_DEFINITIONS.get(suite, {})
|
setup_obj = SETUP_DEFINITIONS[setup]
|
||||||
defaults: dict = suite_def.get("defaults", {})
|
logger.info(f"Applying setup '{setup}'{' for suite ' + suite if suite else ''}")
|
||||||
for dest, value in defaults.items():
|
# Apply env first
|
||||||
|
for k, v in setup_obj.env.items():
|
||||||
|
if k not in os.environ:
|
||||||
|
os.environ[k] = str(v)
|
||||||
|
# Apply defaults if not provided explicitly
|
||||||
|
for dest, value in setup_obj.defaults.items():
|
||||||
current = getattr(config.option, dest, None)
|
current = getattr(config.option, dest, None)
|
||||||
if not current:
|
if not current:
|
||||||
setattr(config.option, dest, value)
|
setattr(config.option, dest, value)
|
||||||
|
@ -120,6 +134,13 @@ def pytest_addoption(parser):
|
||||||
default=384,
|
default=384,
|
||||||
help="Output dimensionality of the embedding model to use for testing. Default: 384",
|
help="Output dimensionality of the embedding model to use for testing. Default: 384",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
parser.addoption(
|
||||||
|
"--inference-mode",
|
||||||
|
help="Inference mode: { record, replay, live } (default: replay)",
|
||||||
|
choices=["record", "replay", "live"],
|
||||||
|
default="replay",
|
||||||
|
)
|
||||||
parser.addoption(
|
parser.addoption(
|
||||||
"--report",
|
"--report",
|
||||||
help="Path where the test report should be written, e.g. --report=/path/to/report.md",
|
help="Path where the test report should be written, e.g. --report=/path/to/report.md",
|
||||||
|
@ -127,14 +148,18 @@ def pytest_addoption(parser):
|
||||||
|
|
||||||
available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys()))
|
available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys()))
|
||||||
suite_help = (
|
suite_help = (
|
||||||
"Comma-separated integration test suites to narrow collection and prefill defaults. "
|
f"Single test suite to run (narrows collection). Available: {available_suites}. Example: --suite=responses"
|
||||||
"Available: "
|
|
||||||
f"{available_suites}. "
|
|
||||||
"Explicit CLI flags (e.g., --text-model) override suite defaults. "
|
|
||||||
"Examples: --suite=responses or --suite=responses,vision."
|
|
||||||
)
|
)
|
||||||
parser.addoption("--suite", help=suite_help)
|
parser.addoption("--suite", help=suite_help)
|
||||||
|
|
||||||
|
# Global setups for any suite
|
||||||
|
available_setups = ", ".join(sorted(SETUP_DEFINITIONS.keys()))
|
||||||
|
setup_help = (
|
||||||
|
f"Global test setup configuration. Available: {available_setups}. "
|
||||||
|
"Can be used with any suite. Example: --setup=ollama"
|
||||||
|
)
|
||||||
|
parser.addoption("--setup", help=setup_help)
|
||||||
|
|
||||||
|
|
||||||
MODEL_SHORT_IDS = {
|
MODEL_SHORT_IDS = {
|
||||||
"meta-llama/Llama-3.2-3B-Instruct": "3B",
|
"meta-llama/Llama-3.2-3B-Instruct": "3B",
|
||||||
|
@ -221,16 +246,12 @@ pytest_plugins = ["tests.integration.fixtures.common"]
|
||||||
|
|
||||||
def pytest_ignore_collect(path: str, config: pytest.Config) -> bool:
|
def pytest_ignore_collect(path: str, config: pytest.Config) -> bool:
|
||||||
"""Skip collecting paths outside the selected suite roots for speed."""
|
"""Skip collecting paths outside the selected suite roots for speed."""
|
||||||
suites_raw = config.getoption("--suite")
|
suite = config.getoption("--suite")
|
||||||
if not suites_raw:
|
if not suite:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
names = [p.strip() for p in str(suites_raw).split(",") if p.strip()]
|
sobj = SUITE_DEFINITIONS.get(suite)
|
||||||
roots: list[str] = []
|
roots: list[str] = sobj.get("roots", []) if isinstance(sobj, dict) else getattr(sobj, "roots", [])
|
||||||
for name in names:
|
|
||||||
suite_def = SUITE_DEFINITIONS.get(name)
|
|
||||||
if suite_def:
|
|
||||||
roots.extend(suite_def.get("roots", []))
|
|
||||||
if not roots:
|
if not roots:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
|
@ -76,6 +76,9 @@ def skip_if_doesnt_support_n(client_with_models, model_id):
|
||||||
"remote::gemini",
|
"remote::gemini",
|
||||||
# https://docs.anthropic.com/en/api/openai-sdk#simple-fields
|
# https://docs.anthropic.com/en/api/openai-sdk#simple-fields
|
||||||
"remote::anthropic",
|
"remote::anthropic",
|
||||||
|
"remote::vertexai",
|
||||||
|
# Error code: 400 - [{'error': {'code': 400, 'message': 'Unable to submit request because candidateCount must be 1 but
|
||||||
|
# the entered value was 2. Update the candidateCount value and try again.', 'status': 'INVALID_ARGUMENT'}
|
||||||
):
|
):
|
||||||
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")
|
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")
|
||||||
|
|
||||||
|
|
|
@ -8,46 +8,112 @@
|
||||||
# For example:
|
# For example:
|
||||||
#
|
#
|
||||||
# ```bash
|
# ```bash
|
||||||
# pytest tests/integration/ --suite=vision
|
# pytest tests/integration/ --suite=vision --setup=ollama
|
||||||
# ```
|
# ```
|
||||||
#
|
#
|
||||||
# Each suite can:
|
"""
|
||||||
# - restrict collection to specific roots (dirs or files)
|
Each suite defines what to run (roots). Suites can be run with different global setups, which are defined below in this file.
|
||||||
# - provide default CLI option values (e.g. text_model, embedding_model, etc.)
|
Setups provide environment variables and model defaults that can be reused across multiple suites.
|
||||||
|
|
||||||
|
CLI examples:
|
||||||
|
pytest tests/integration --suite=responses --setup=gpt
|
||||||
|
pytest tests/integration --suite=vision --setup=ollama
|
||||||
|
pytest tests/integration --suite=base --setup=vllm
|
||||||
|
"""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
this_dir = Path(__file__).parent
|
this_dir = Path(__file__).parent
|
||||||
default_roots = [
|
|
||||||
|
|
||||||
|
class Suite(BaseModel):
|
||||||
|
name: str
|
||||||
|
roots: list[str]
|
||||||
|
default_setup: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class Setup(BaseModel):
|
||||||
|
"""A reusable test configuration with environment and CLI defaults."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
defaults: dict[str, str] = Field(default_factory=dict)
|
||||||
|
env: dict[str, str] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
# Global setups - can technically be used with any suite, but in practice some setups
|
||||||
|
# only work with specific test suites.
|
||||||
|
SETUP_DEFINITIONS: dict[str, Setup] = {
|
||||||
|
"ollama": Setup(
|
||||||
|
name="ollama",
|
||||||
|
description="Local Ollama provider with text + safety models",
|
||||||
|
env={
|
||||||
|
"OLLAMA_URL": "http://0.0.0.0:11434",
|
||||||
|
"SAFETY_MODEL": "ollama/llama-guard3:1b",
|
||||||
|
},
|
||||||
|
defaults={
|
||||||
|
"text_model": "ollama/llama3.2:3b-instruct-fp16",
|
||||||
|
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
"safety_model": "ollama/llama-guard3:1b",
|
||||||
|
"safety_shield": "llama-guard",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"ollama-vision": Setup(
|
||||||
|
name="ollama",
|
||||||
|
description="Local Ollama provider with a vision model",
|
||||||
|
env={
|
||||||
|
"OLLAMA_URL": "http://0.0.0.0:11434",
|
||||||
|
},
|
||||||
|
defaults={
|
||||||
|
"vision_model": "ollama/llama3.2-vision:11b",
|
||||||
|
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"vllm": Setup(
|
||||||
|
name="vllm",
|
||||||
|
description="vLLM provider with a text model",
|
||||||
|
env={
|
||||||
|
"VLLM_URL": "http://localhost:8000/v1",
|
||||||
|
},
|
||||||
|
defaults={
|
||||||
|
"text_model": "vllm/meta-llama/Llama-3.2-1B-Instruct",
|
||||||
|
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
"gpt": Setup(
|
||||||
|
name="gpt",
|
||||||
|
description="OpenAI GPT models for high-quality responses and tool calling",
|
||||||
|
defaults={
|
||||||
|
"text_model": "openai/gpt-4o",
|
||||||
|
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
base_roots = [
|
||||||
str(p)
|
str(p)
|
||||||
for p in this_dir.glob("*")
|
for p in this_dir.glob("*")
|
||||||
if p.is_dir()
|
if p.is_dir()
|
||||||
and p.name not in ("__pycache__", "fixtures", "test_cases", "recordings", "responses", "post_training")
|
and p.name not in ("__pycache__", "fixtures", "test_cases", "recordings", "responses", "post_training")
|
||||||
]
|
]
|
||||||
|
|
||||||
SUITE_DEFINITIONS: dict[str, dict] = {
|
SUITE_DEFINITIONS: dict[str, Suite] = {
|
||||||
"base": {
|
"base": Suite(
|
||||||
"description": "Base suite that includes most tests but runs them with a text Ollama model",
|
name="base",
|
||||||
"roots": default_roots,
|
roots=base_roots,
|
||||||
"defaults": {
|
default_setup="ollama",
|
||||||
"text_model": "ollama/llama3.2:3b-instruct-fp16",
|
),
|
||||||
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
"responses": Suite(
|
||||||
},
|
name="responses",
|
||||||
},
|
roots=["tests/integration/responses"],
|
||||||
"responses": {
|
default_setup="gpt",
|
||||||
"description": "Suite that includes only the OpenAI Responses tests; needs a strong tool-calling model",
|
),
|
||||||
"roots": ["tests/integration/responses"],
|
"vision": Suite(
|
||||||
"defaults": {
|
name="vision",
|
||||||
"text_model": "openai/gpt-4o",
|
roots=["tests/integration/inference/test_vision_inference.py"],
|
||||||
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
default_setup="ollama-vision",
|
||||||
},
|
),
|
||||||
},
|
|
||||||
"vision": {
|
|
||||||
"description": "Suite that includes only the vision tests",
|
|
||||||
"roots": ["tests/integration/inference/test_vision_inference.py"],
|
|
||||||
"defaults": {
|
|
||||||
"vision_model": "ollama/llama3.2-vision:11b",
|
|
||||||
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
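Registering an additional suite is just another entry in this dict. As a hedged illustration, a hypothetical suite for the `post_training` tests excluded from `base_roots` might look like this (the path and default setup are assumptions, not part of this diff):

```python
# Hypothetical extra suite that reuses the existing ollama setup.
from tests.integration.suites import SUITE_DEFINITIONS, Suite

SUITE_DEFINITIONS["post_training"] = Suite(
    name="post_training",
    roots=["tests/integration/post_training"],  # assumed path
    default_setup="ollama",
)
```

With that entry, `pytest tests/integration --suite=post_training` would narrow collection to that directory and default to the ollama setup unless `--setup` is passed explicitly.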
|
|
|
@ -183,6 +183,110 @@ def test_vector_db_insert_from_url_and_query(
|
||||||
assert any("llama2" in chunk.content.lower() for chunk in response2.chunks)
|
assert any("llama2" in chunk.content.lower() for chunk in response2.chunks)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rag_tool_openai_apis(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||||
|
vector_db_id = "test_openai_vector_db"
|
||||||
|
|
||||||
|
client_with_empty_registry.vector_dbs.register(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
embedding_model=embedding_model_id,
|
||||||
|
embedding_dimension=embedding_dimension,
|
||||||
|
)
|
||||||
|
|
||||||
|
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||||
|
actual_vector_db_id = available_vector_dbs[0]
|
||||||
|
|
||||||
|
# different document formats that should work with OpenAI APIs
|
||||||
|
documents = [
|
||||||
|
Document(
|
||||||
|
document_id="text-doc",
|
||||||
|
content="This is a plain text document about machine learning algorithms.",
|
||||||
|
metadata={"type": "text", "category": "AI"},
|
||||||
|
),
|
||||||
|
Document(
|
||||||
|
document_id="url-doc",
|
||||||
|
content="https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/chat.rst",
|
||||||
|
mime_type="text/plain",
|
||||||
|
metadata={"type": "url", "source": "pytorch"},
|
||||||
|
),
|
||||||
|
Document(
|
||||||
|
document_id="data-url-doc",
|
||||||
|
content="data:text/plain;base64,VGhpcyBpcyBhIGRhdGEgVVJMIGRvY3VtZW50IGFib3V0IGRlZXAgbGVhcm5pbmcu", # "This is a data URL document about deep learning."
|
||||||
|
metadata={"type": "data_url", "encoding": "base64"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||||
|
documents=documents,
|
||||||
|
vector_db_id=actual_vector_db_id,
|
||||||
|
chunk_size_in_tokens=256,
|
||||||
|
)
|
||||||
|
|
||||||
|
files_list = client_with_empty_registry.files.list()
|
||||||
|
assert len(files_list.data) >= len(documents), (
|
||||||
|
f"Expected at least {len(documents)} files, got {len(files_list.data)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_store_files = client_with_empty_registry.vector_io.openai_list_files_in_vector_store(
|
||||||
|
vector_store_id=actual_vector_db_id
|
||||||
|
)
|
||||||
|
assert len(vector_store_files.data) >= len(documents), f"Expected at least {len(documents)} files in vector store"
|
||||||
|
|
||||||
|
response = client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||||
|
vector_db_ids=[actual_vector_db_id],
|
||||||
|
content="Tell me about machine learning and deep learning",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert_valid_text_response(response)
|
||||||
|
content_text = " ".join([chunk.text for chunk in response.content]).lower()
|
||||||
|
assert "machine learning" in content_text or "deep learning" in content_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_rag_tool_exception_handling(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||||
|
vector_db_id = "test_exception_handling"
|
||||||
|
|
||||||
|
client_with_empty_registry.vector_dbs.register(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
embedding_model=embedding_model_id,
|
||||||
|
embedding_dimension=embedding_dimension,
|
||||||
|
)
|
||||||
|
|
||||||
|
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||||
|
actual_vector_db_id = available_vector_dbs[0]
|
||||||
|
|
||||||
|
documents = [
|
||||||
|
Document(
|
||||||
|
document_id="valid-doc",
|
||||||
|
content="This is a valid document that should be processed successfully.",
|
||||||
|
metadata={"status": "valid"},
|
||||||
|
),
|
||||||
|
Document(
|
||||||
|
document_id="invalid-url-doc",
|
||||||
|
content="https://nonexistent-domain-12345.com/invalid.txt",
|
||||||
|
metadata={"status": "invalid_url"},
|
||||||
|
),
|
||||||
|
Document(
|
||||||
|
document_id="another-valid-doc",
|
||||||
|
content="This is another valid document for testing resilience.",
|
||||||
|
metadata={"status": "valid"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||||
|
documents=documents,
|
||||||
|
vector_db_id=actual_vector_db_id,
|
||||||
|
chunk_size_in_tokens=256,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||||
|
vector_db_ids=[actual_vector_db_id],
|
||||||
|
content="valid document",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert_valid_text_response(response)
|
||||||
|
content_text = " ".join([chunk.text for chunk in response.content]).lower()
|
||||||
|
assert "valid document" in content_text
|
||||||
|
|
||||||
|
|
||||||
def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||||
providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
|
providers = [p for p in client_with_empty_registry.providers.list() if p.api == "vector_io"]
|
||||||
assert len(providers) > 0
|
assert len(providers) > 0
|
||||||
|
@ -249,3 +353,107 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
|
||||||
"chunk_template": "This should raise a ValueError because it is missing the proper template variables",
|
"chunk_template": "This should raise a ValueError because it is missing the proper template variables",
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rag_tool_query_generation(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||||
|
vector_db_id = "test_query_generation_db"
|
||||||
|
|
||||||
|
client_with_empty_registry.vector_dbs.register(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
embedding_model=embedding_model_id,
|
||||||
|
embedding_dimension=embedding_dimension,
|
||||||
|
)
|
||||||
|
|
||||||
|
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||||
|
actual_vector_db_id = available_vector_dbs[0]
|
||||||
|
|
||||||
|
documents = [
|
||||||
|
Document(
|
||||||
|
document_id="ai-doc",
|
||||||
|
content="Artificial intelligence and machine learning are transforming technology.",
|
||||||
|
metadata={"category": "AI"},
|
||||||
|
),
|
||||||
|
Document(
|
||||||
|
document_id="banana-doc",
|
||||||
|
content="Don't bring a banana to a knife fight.",
|
||||||
|
metadata={"category": "wisdom"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||||
|
documents=documents,
|
||||||
|
vector_db_id=actual_vector_db_id,
|
||||||
|
chunk_size_in_tokens=256,
|
||||||
|
)
|
||||||
|
|
||||||
|
response = client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||||
|
vector_db_ids=[actual_vector_db_id],
|
||||||
|
content="Tell me about AI",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert_valid_text_response(response)
|
||||||
|
content_text = " ".join([chunk.text for chunk in response.content]).lower()
|
||||||
|
assert "artificial intelligence" in content_text or "machine learning" in content_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_rag_tool_pdf_data_url_handling(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||||
|
vector_db_id = "test_pdf_data_url_db"
|
||||||
|
|
||||||
|
client_with_empty_registry.vector_dbs.register(
|
||||||
|
vector_db_id=vector_db_id,
|
||||||
|
embedding_model=embedding_model_id,
|
||||||
|
embedding_dimension=embedding_dimension,
|
||||||
|
)
|
||||||
|
|
||||||
|
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||||
|
actual_vector_db_id = available_vector_dbs[0]
|
||||||
|
|
||||||
|
sample_pdf = b"%PDF-1.3\n3 0 obj\n<</Type /Page\n/Parent 1 0 R\n/Resources 2 0 R\n/Contents 4 0 R>>\nendobj\n4 0 obj\n<</Filter /FlateDecode /Length 115>>\nstream\nx\x9c\x15\xcc1\x0e\x820\x18@\xe1\x9dS\xbcM]jk$\xd5\xd5(\x83!\x86\xa1\x17\xf8\xa3\xa5`LIh+\xd7W\xc6\xf7\r\xef\xc0\xbd\xd2\xaa\xb6,\xd5\xc5\xb1o\x0c\xa6VZ\xe3znn%\xf3o\xab\xb1\xe7\xa3:Y\xdc\x8bm\xeb\xf3&1\xc8\xd7\xd3\x97\xc82\xe6\x81\x87\xe42\xcb\x87Vb(\x12<\xdd<=}Jc\x0cL\x91\xee\xda$\xb5\xc3\xbd\xd7\xe9\x0f\x8d\x97 $\nendstream\nendobj\n1 0 obj\n<</Type /Pages\n/Kids [3 0 R ]\n/Count 1\n/MediaBox [0 0 595.28 841.89]\n>>\nendobj\n5 0 obj\n<</Type /Font\n/BaseFont /Helvetica\n/Subtype /Type1\n/Encoding /WinAnsiEncoding\n>>\nendobj\n2 0 obj\n<<\n/ProcSet [/PDF /Text /ImageB /ImageC /ImageI]\n/Font <<\n/F1 5 0 R\n>>\n/XObject <<\n>>\n>>\nendobj\n6 0 obj\n<<\n/Producer (PyFPDF 1.7.2 http://pyfpdf.googlecode.com/)\n/Title (This is a sample title.)\n/Author (Llama Stack Developers)\n/CreationDate (D:20250312165548)\n>>\nendobj\n7 0 obj\n<<\n/Type /Catalog\n/Pages 1 0 R\n/OpenAction [3 0 R /FitH null]\n/PageLayout /OneColumn\n>>\nendobj\nxref\n0 8\n0000000000 65535 f \n0000000272 00000 n \n0000000455 00000 n \n0000000009 00000 n \n0000000087 00000 n \n0000000359 00000 n \n0000000559 00000 n \n0000000734 00000 n \ntrailer\n<<\n/Size 8\n/Root 7 0 R\n/Info 6 0 R\n>>\nstartxref\n837\n%%EOF\n"
|
||||||
|
|
||||||
|
import base64
|
||||||
|
|
||||||
|
pdf_base64 = base64.b64encode(sample_pdf).decode("utf-8")
|
||||||
|
pdf_data_url = f"data:application/pdf;base64,{pdf_base64}"
|
||||||
|
|
||||||
|
documents = [
|
||||||
|
Document(
|
||||||
|
document_id="test-pdf-data-url",
|
||||||
|
content=pdf_data_url,
|
||||||
|
metadata={"type": "pdf", "source": "data_url"},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||||
|
documents=documents,
|
||||||
|
vector_db_id=actual_vector_db_id,
|
||||||
|
chunk_size_in_tokens=256,
|
||||||
|
)
|
||||||
|
|
||||||
|
files_list = client_with_empty_registry.files.list()
|
||||||
|
assert len(files_list.data) >= 1, "PDF should have been uploaded to Files API"
|
||||||
|
|
||||||
|
pdf_file = None
|
||||||
|
for file in files_list.data:
|
||||||
|
if file.filename and "test-pdf-data-url" in file.filename:
|
||||||
|
pdf_file = file
|
||||||
|
break
|
||||||
|
|
||||||
|
assert pdf_file is not None, "PDF file should be found in Files API"
|
||||||
|
assert pdf_file.bytes == len(sample_pdf), f"File size should match original PDF ({len(sample_pdf)} bytes)"
|
||||||
|
|
||||||
|
file_content = client_with_empty_registry.files.retrieve_content(pdf_file.id)
|
||||||
|
assert file_content.startswith(b"%PDF-"), "Retrieved file should be a valid PDF"
|
||||||
|
|
||||||
|
vector_store_files = client_with_empty_registry.vector_io.openai_list_files_in_vector_store(
|
||||||
|
vector_store_id=actual_vector_db_id
|
||||||
|
)
|
||||||
|
assert len(vector_store_files.data) >= 1, "PDF should be attached to vector store"
|
||||||
|
|
||||||
|
response = client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||||
|
vector_db_ids=[actual_vector_db_id],
|
||||||
|
content="sample title",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert_valid_text_response(response)
|
||||||
|
content_text = " ".join([chunk.text for chunk in response.content]).lower()
|
||||||
|
assert "sample title" in content_text or "title" in content_text
|
||||||
|
|
|
@ -178,3 +178,41 @@ def test_content_from_data_and_mime_type_both_encodings_fail():
|
||||||
# Should raise an exception instead of returning empty string
|
# Should raise an exception instead of returning empty string
|
||||||
with pytest.raises(UnicodeDecodeError):
|
with pytest.raises(UnicodeDecodeError):
|
||||||
content_from_data_and_mime_type(data, mime_type)
|
content_from_data_and_mime_type(data, mime_type)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_memory_tool_error_handling():
|
||||||
|
"""Test that memory tool handles various failures gracefully without crashing."""
|
||||||
|
from llama_stack.providers.inline.tool_runtime.rag.config import RagToolRuntimeConfig
|
||||||
|
from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRuntimeImpl
|
||||||
|
|
||||||
|
config = RagToolRuntimeConfig()
|
||||||
|
memory_tool = MemoryToolRuntimeImpl(
|
||||||
|
config=config,
|
||||||
|
vector_io_api=AsyncMock(),
|
||||||
|
inference_api=AsyncMock(),
|
||||||
|
files_api=AsyncMock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
docs = [
|
||||||
|
RAGDocument(document_id="good_doc", content="Good content", metadata={}),
|
||||||
|
RAGDocument(document_id="bad_url_doc", content=URL(uri="https://bad.url"), metadata={}),
|
||||||
|
RAGDocument(document_id="another_good_doc", content="Another good content", metadata={}),
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_file1 = MagicMock()
|
||||||
|
mock_file1.id = "file_good1"
|
||||||
|
mock_file2 = MagicMock()
|
||||||
|
mock_file2.id = "file_good2"
|
||||||
|
memory_tool.files_api.openai_upload_file.side_effect = [mock_file1, mock_file2]
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as mock_client:
|
||||||
|
mock_instance = AsyncMock()
|
||||||
|
mock_instance.get.side_effect = Exception("Bad URL")
|
||||||
|
mock_client.return_value.__aenter__.return_value = mock_instance
|
||||||
|
|
||||||
|
# won't raise exception despite one document failing
|
||||||
|
await memory_tool.insert(docs, "vector_store_123")
|
||||||
|
|
||||||
|
# processed 2 documents successfully, skipped 1
|
||||||
|
assert memory_tool.files_api.openai_upload_file.call_count == 2
|
||||||
|
assert memory_tool.vector_io_api.openai_attach_file_to_vector_store.call_count == 2
|
||||||
|
|
|
@ -81,3 +81,58 @@ class TestRagQuery:
|
||||||
# Test that invalid mode raises an error
|
# Test that invalid mode raises an error
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
RAGQueryConfig(mode="wrong_mode")
|
RAGQueryConfig(mode="wrong_mode")
|
||||||
|
|
||||||
|
async def test_query_adds_vector_db_id_to_chunk_metadata(self):
|
||||||
|
rag_tool = MemoryToolRuntimeImpl(
|
||||||
|
config=MagicMock(),
|
||||||
|
vector_io_api=MagicMock(),
|
||||||
|
inference_api=MagicMock(),
|
||||||
|
files_api=MagicMock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_db_ids = ["db1", "db2"]
|
||||||
|
|
||||||
|
# Fake chunks from each DB
|
||||||
|
chunk_metadata1 = ChunkMetadata(
|
||||||
|
document_id="doc1",
|
||||||
|
chunk_id="chunk1",
|
||||||
|
source="test_source1",
|
||||||
|
metadata_token_count=5,
|
||||||
|
)
|
||||||
|
chunk1 = Chunk(
|
||||||
|
content="chunk from db1",
|
||||||
|
metadata={"vector_db_id": "db1", "document_id": "doc1"},
|
||||||
|
stored_chunk_id="c1",
|
||||||
|
chunk_metadata=chunk_metadata1,
|
||||||
|
)
|
||||||
|
|
||||||
|
chunk_metadata2 = ChunkMetadata(
|
||||||
|
document_id="doc2",
|
||||||
|
chunk_id="chunk2",
|
||||||
|
source="test_source2",
|
||||||
|
metadata_token_count=5,
|
||||||
|
)
|
||||||
|
chunk2 = Chunk(
|
||||||
|
content="chunk from db2",
|
||||||
|
metadata={"vector_db_id": "db2", "document_id": "doc2"},
|
||||||
|
stored_chunk_id="c2",
|
||||||
|
chunk_metadata=chunk_metadata2,
|
||||||
|
)
|
||||||
|
|
||||||
|
rag_tool.vector_io_api.query_chunks = AsyncMock(
|
||||||
|
side_effect=[
|
||||||
|
QueryChunksResponse(chunks=[chunk1], scores=[0.9]),
|
||||||
|
QueryChunksResponse(chunks=[chunk2], scores=[0.8]),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids)
|
||||||
|
returned_chunks = result.metadata["chunks"]
|
||||||
|
returned_scores = result.metadata["scores"]
|
||||||
|
returned_doc_ids = result.metadata["document_ids"]
|
||||||
|
returned_vector_db_ids = result.metadata["vector_db_ids"]
|
||||||
|
|
||||||
|
assert returned_chunks == ["chunk from db1", "chunk from db2"]
|
||||||
|
assert returned_scores == (0.9, 0.8)
|
||||||
|
assert returned_doc_ids == ["doc1", "doc2"]
|
||||||
|
assert returned_vector_db_ids == ["db1", "db2"]
|
||||||
|
|