mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Merge branch 'main' into fix/embedding-model-type
This commit is contained in:
commit
309f06829c
59 changed files with 1005 additions and 339 deletions
30
.github/actions/run-and-record-tests/action.yml
vendored
30
.github/actions/run-and-record-tests/action.yml
vendored
|
@ -2,13 +2,6 @@ name: 'Run and Record Tests'
|
||||||
description: 'Run integration tests and handle recording/artifact upload'
|
description: 'Run integration tests and handle recording/artifact upload'
|
||||||
|
|
||||||
inputs:
|
inputs:
|
||||||
test-subdirs:
|
|
||||||
description: 'Comma-separated list of test subdirectories to run'
|
|
||||||
required: true
|
|
||||||
test-pattern:
|
|
||||||
description: 'Regex pattern to pass to pytest -k'
|
|
||||||
required: false
|
|
||||||
default: ''
|
|
||||||
stack-config:
|
stack-config:
|
||||||
description: 'Stack configuration to use'
|
description: 'Stack configuration to use'
|
||||||
required: true
|
required: true
|
||||||
|
@ -18,10 +11,18 @@ inputs:
|
||||||
inference-mode:
|
inference-mode:
|
||||||
description: 'Inference mode (record or replay)'
|
description: 'Inference mode (record or replay)'
|
||||||
required: true
|
required: true
|
||||||
run-vision-tests:
|
test-suite:
|
||||||
description: 'Whether to run vision tests'
|
description: 'Test suite to use: base, responses, vision, etc.'
|
||||||
required: false
|
required: false
|
||||||
default: 'false'
|
default: ''
|
||||||
|
test-subdirs:
|
||||||
|
description: 'Comma-separated list of test subdirectories to run; overrides test-suite'
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
test-pattern:
|
||||||
|
description: 'Regex pattern to pass to pytest -k'
|
||||||
|
required: false
|
||||||
|
default: ''
|
||||||
|
|
||||||
runs:
|
runs:
|
||||||
using: 'composite'
|
using: 'composite'
|
||||||
|
@ -42,7 +43,7 @@ runs:
|
||||||
--test-subdirs '${{ inputs.test-subdirs }}' \
|
--test-subdirs '${{ inputs.test-subdirs }}' \
|
||||||
--test-pattern '${{ inputs.test-pattern }}' \
|
--test-pattern '${{ inputs.test-pattern }}' \
|
||||||
--inference-mode '${{ inputs.inference-mode }}' \
|
--inference-mode '${{ inputs.inference-mode }}' \
|
||||||
${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }} \
|
--test-suite '${{ inputs.test-suite }}' \
|
||||||
| tee pytest-${{ inputs.inference-mode }}.log
|
| tee pytest-${{ inputs.inference-mode }}.log
|
||||||
|
|
||||||
|
|
||||||
|
@ -57,12 +58,7 @@ runs:
|
||||||
echo "New recordings detected, committing and pushing"
|
echo "New recordings detected, committing and pushing"
|
||||||
git add tests/integration/recordings/
|
git add tests/integration/recordings/
|
||||||
|
|
||||||
if [ "${{ inputs.run-vision-tests }}" == "true" ]; then
|
git commit -m "Recordings update from CI (test-suite: ${{ inputs.test-suite }})"
|
||||||
git commit -m "Recordings update from CI (vision)"
|
|
||||||
else
|
|
||||||
git commit -m "Recordings update from CI"
|
|
||||||
fi
|
|
||||||
|
|
||||||
git fetch origin ${{ github.ref_name }}
|
git fetch origin ${{ github.ref_name }}
|
||||||
git rebase origin/${{ github.ref_name }}
|
git rebase origin/${{ github.ref_name }}
|
||||||
echo "Rebased successfully"
|
echo "Rebased successfully"
|
||||||
|
|
8
.github/actions/setup-ollama/action.yml
vendored
8
.github/actions/setup-ollama/action.yml
vendored
|
@ -1,17 +1,17 @@
|
||||||
name: Setup Ollama
|
name: Setup Ollama
|
||||||
description: Start Ollama
|
description: Start Ollama
|
||||||
inputs:
|
inputs:
|
||||||
run-vision-tests:
|
test-suite:
|
||||||
description: 'Run vision tests: "true" or "false"'
|
description: 'Test suite to use: base, responses, vision, etc.'
|
||||||
required: false
|
required: false
|
||||||
default: 'false'
|
default: ''
|
||||||
runs:
|
runs:
|
||||||
using: "composite"
|
using: "composite"
|
||||||
steps:
|
steps:
|
||||||
- name: Start Ollama
|
- name: Start Ollama
|
||||||
shell: bash
|
shell: bash
|
||||||
run: |
|
run: |
|
||||||
if [ "${{ inputs.run-vision-tests }}" == "true" ]; then
|
if [ "${{ inputs.test-suite }}" == "vision" ]; then
|
||||||
image="ollama-with-vision-model"
|
image="ollama-with-vision-model"
|
||||||
else
|
else
|
||||||
image="ollama-with-models"
|
image="ollama-with-models"
|
||||||
|
|
|
@ -12,10 +12,10 @@ inputs:
|
||||||
description: 'Provider to setup (ollama or vllm)'
|
description: 'Provider to setup (ollama or vllm)'
|
||||||
required: true
|
required: true
|
||||||
default: 'ollama'
|
default: 'ollama'
|
||||||
run-vision-tests:
|
test-suite:
|
||||||
description: 'Whether to setup provider for vision tests'
|
description: 'Test suite to use: base, responses, vision, etc.'
|
||||||
required: false
|
required: false
|
||||||
default: 'false'
|
default: ''
|
||||||
inference-mode:
|
inference-mode:
|
||||||
description: 'Inference mode (record or replay)'
|
description: 'Inference mode (record or replay)'
|
||||||
required: true
|
required: true
|
||||||
|
@ -33,7 +33,7 @@ runs:
|
||||||
if: ${{ inputs.provider == 'ollama' && inputs.inference-mode == 'record' }}
|
if: ${{ inputs.provider == 'ollama' && inputs.inference-mode == 'record' }}
|
||||||
uses: ./.github/actions/setup-ollama
|
uses: ./.github/actions/setup-ollama
|
||||||
with:
|
with:
|
||||||
run-vision-tests: ${{ inputs.run-vision-tests }}
|
test-suite: ${{ inputs.test-suite }}
|
||||||
|
|
||||||
- name: Setup vllm
|
- name: Setup vllm
|
||||||
if: ${{ inputs.provider == 'vllm' && inputs.inference-mode == 'record' }}
|
if: ${{ inputs.provider == 'vllm' && inputs.inference-mode == 'record' }}
|
||||||
|
|
2
.github/workflows/README.md
vendored
2
.github/workflows/README.md
vendored
|
@ -8,7 +8,7 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl
|
||||||
| Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script |
|
| Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script |
|
||||||
| Integration Auth Tests | [integration-auth-tests.yml](integration-auth-tests.yml) | Run the integration test suite with Kubernetes authentication |
|
| Integration Auth Tests | [integration-auth-tests.yml](integration-auth-tests.yml) | Run the integration test suite with Kubernetes authentication |
|
||||||
| SqlStore Integration Tests | [integration-sql-store-tests.yml](integration-sql-store-tests.yml) | Run the integration test suite with SqlStore |
|
| SqlStore Integration Tests | [integration-sql-store-tests.yml](integration-sql-store-tests.yml) | Run the integration test suite with SqlStore |
|
||||||
| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suite from tests/integration in replay mode |
|
| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suites from tests/integration in replay mode |
|
||||||
| Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers |
|
| Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers |
|
||||||
| Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks |
|
| Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks |
|
||||||
| Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build |
|
| Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build |
|
||||||
|
|
20
.github/workflows/integration-tests.yml
vendored
20
.github/workflows/integration-tests.yml
vendored
|
@ -1,6 +1,6 @@
|
||||||
name: Integration Tests (Replay)
|
name: Integration Tests (Replay)
|
||||||
|
|
||||||
run-name: Run the integration test suite from tests/integration in replay mode
|
run-name: Run the integration test suites from tests/integration in replay mode
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
|
@ -32,14 +32,6 @@ on:
|
||||||
description: 'Test against a specific provider'
|
description: 'Test against a specific provider'
|
||||||
type: string
|
type: string
|
||||||
default: 'ollama'
|
default: 'ollama'
|
||||||
test-subdirs:
|
|
||||||
description: 'Comma-separated list of test subdirectories to run'
|
|
||||||
type: string
|
|
||||||
default: ''
|
|
||||||
test-pattern:
|
|
||||||
description: 'Regex pattern to pass to pytest -k'
|
|
||||||
type: string
|
|
||||||
default: ''
|
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
# Skip concurrency for pushes to main - each commit should be tested independently
|
# Skip concurrency for pushes to main - each commit should be tested independently
|
||||||
|
@ -50,7 +42,7 @@ jobs:
|
||||||
|
|
||||||
run-replay-mode-tests:
|
run-replay-mode-tests:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }}
|
name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, {4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.test-suite) }}
|
||||||
|
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
@ -61,7 +53,7 @@ jobs:
|
||||||
# Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
|
# Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
|
||||||
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
|
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
|
||||||
client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
|
client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
|
||||||
run-vision-tests: [true, false]
|
test-suite: [base, vision]
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
|
@ -73,15 +65,13 @@ jobs:
|
||||||
python-version: ${{ matrix.python-version }}
|
python-version: ${{ matrix.python-version }}
|
||||||
client-version: ${{ matrix.client-version }}
|
client-version: ${{ matrix.client-version }}
|
||||||
provider: ${{ matrix.provider }}
|
provider: ${{ matrix.provider }}
|
||||||
run-vision-tests: ${{ matrix.run-vision-tests }}
|
test-suite: ${{ matrix.test-suite }}
|
||||||
inference-mode: 'replay'
|
inference-mode: 'replay'
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
uses: ./.github/actions/run-and-record-tests
|
uses: ./.github/actions/run-and-record-tests
|
||||||
with:
|
with:
|
||||||
test-subdirs: ${{ inputs.test-subdirs }}
|
|
||||||
test-pattern: ${{ inputs.test-pattern }}
|
|
||||||
stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }}
|
stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }}
|
||||||
provider: ${{ matrix.provider }}
|
provider: ${{ matrix.provider }}
|
||||||
inference-mode: 'replay'
|
inference-mode: 'replay'
|
||||||
run-vision-tests: ${{ matrix.run-vision-tests }}
|
test-suite: ${{ matrix.test-suite }}
|
||||||
|
|
32
.github/workflows/record-integration-tests.yml
vendored
32
.github/workflows/record-integration-tests.yml
vendored
|
@ -10,18 +10,18 @@ run-name: Run the integration test suite from tests/integration
|
||||||
on:
|
on:
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
inputs:
|
inputs:
|
||||||
test-subdirs:
|
|
||||||
description: 'Comma-separated list of test subdirectories to run'
|
|
||||||
type: string
|
|
||||||
default: ''
|
|
||||||
test-provider:
|
test-provider:
|
||||||
description: 'Test against a specific provider'
|
description: 'Test against a specific provider'
|
||||||
type: string
|
type: string
|
||||||
default: 'ollama'
|
default: 'ollama'
|
||||||
run-vision-tests:
|
test-suite:
|
||||||
description: 'Whether to run vision tests'
|
description: 'Test suite to use: base, responses, vision, etc.'
|
||||||
type: boolean
|
type: string
|
||||||
default: false
|
default: ''
|
||||||
|
test-subdirs:
|
||||||
|
description: 'Comma-separated list of test subdirectories to run; overrides test-suite'
|
||||||
|
type: string
|
||||||
|
default: ''
|
||||||
test-pattern:
|
test-pattern:
|
||||||
description: 'Regex pattern to pass to pytest -k'
|
description: 'Regex pattern to pass to pytest -k'
|
||||||
type: string
|
type: string
|
||||||
|
@ -38,11 +38,11 @@ jobs:
|
||||||
- name: Echo workflow inputs
|
- name: Echo workflow inputs
|
||||||
run: |
|
run: |
|
||||||
echo "::group::Workflow Inputs"
|
echo "::group::Workflow Inputs"
|
||||||
echo "test-subdirs: ${{ inputs.test-subdirs }}"
|
|
||||||
echo "test-provider: ${{ inputs.test-provider }}"
|
|
||||||
echo "run-vision-tests: ${{ inputs.run-vision-tests }}"
|
|
||||||
echo "test-pattern: ${{ inputs.test-pattern }}"
|
|
||||||
echo "branch: ${{ github.ref_name }}"
|
echo "branch: ${{ github.ref_name }}"
|
||||||
|
echo "test-provider: ${{ inputs.test-provider }}"
|
||||||
|
echo "test-suite: ${{ inputs.test-suite }}"
|
||||||
|
echo "test-subdirs: ${{ inputs.test-subdirs }}"
|
||||||
|
echo "test-pattern: ${{ inputs.test-pattern }}"
|
||||||
echo "::endgroup::"
|
echo "::endgroup::"
|
||||||
|
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
|
@ -56,15 +56,15 @@ jobs:
|
||||||
python-version: "3.12" # Use single Python version for recording
|
python-version: "3.12" # Use single Python version for recording
|
||||||
client-version: "latest"
|
client-version: "latest"
|
||||||
provider: ${{ inputs.test-provider || 'ollama' }}
|
provider: ${{ inputs.test-provider || 'ollama' }}
|
||||||
run-vision-tests: ${{ inputs.run-vision-tests }}
|
test-suite: ${{ inputs.test-suite }}
|
||||||
inference-mode: 'record'
|
inference-mode: 'record'
|
||||||
|
|
||||||
- name: Run and record tests
|
- name: Run and record tests
|
||||||
uses: ./.github/actions/run-and-record-tests
|
uses: ./.github/actions/run-and-record-tests
|
||||||
with:
|
with:
|
||||||
test-pattern: ${{ inputs.test-pattern }}
|
|
||||||
test-subdirs: ${{ inputs.test-subdirs }}
|
|
||||||
stack-config: 'server:ci-tests' # recording must be done with server since more tests are run
|
stack-config: 'server:ci-tests' # recording must be done with server since more tests are run
|
||||||
provider: ${{ inputs.test-provider || 'ollama' }}
|
provider: ${{ inputs.test-provider || 'ollama' }}
|
||||||
inference-mode: 'record'
|
inference-mode: 'record'
|
||||||
run-vision-tests: ${{ inputs.run-vision-tests }}
|
test-suite: ${{ inputs.test-suite }}
|
||||||
|
test-subdirs: ${{ inputs.test-subdirs }}
|
||||||
|
test-pattern: ${{ inputs.test-pattern }}
|
||||||
|
|
|
@ -86,7 +86,7 @@ repos:
|
||||||
language: python
|
language: python
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
require_serial: true
|
require_serial: true
|
||||||
files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$
|
files: ^llama_stack/distributions/.*$|^llama_stack/providers/.*/inference/.*/models\.py$
|
||||||
- id: provider-codegen
|
- id: provider-codegen
|
||||||
name: Provider Codegen
|
name: Provider Codegen
|
||||||
additional_dependencies:
|
additional_dependencies:
|
||||||
|
|
|
@ -3,6 +3,7 @@ image_name: kubernetes-benchmark-demo
|
||||||
apis:
|
apis:
|
||||||
- agents
|
- agents
|
||||||
- inference
|
- inference
|
||||||
|
- safety
|
||||||
- telemetry
|
- telemetry
|
||||||
- tool_runtime
|
- tool_runtime
|
||||||
- vector_io
|
- vector_io
|
||||||
|
@ -30,6 +31,11 @@ providers:
|
||||||
db: ${env.POSTGRES_DB:=llamastack}
|
db: ${env.POSTGRES_DB:=llamastack}
|
||||||
user: ${env.POSTGRES_USER:=llamastack}
|
user: ${env.POSTGRES_USER:=llamastack}
|
||||||
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
password: ${env.POSTGRES_PASSWORD:=llamastack}
|
||||||
|
safety:
|
||||||
|
- provider_id: llama-guard
|
||||||
|
provider_type: inline::llama-guard
|
||||||
|
config:
|
||||||
|
excluded_categories: []
|
||||||
agents:
|
agents:
|
||||||
- provider_id: meta-reference
|
- provider_id: meta-reference
|
||||||
provider_type: inline::meta-reference
|
provider_type: inline::meta-reference
|
||||||
|
@ -95,6 +101,8 @@ models:
|
||||||
- model_id: ${env.INFERENCE_MODEL}
|
- model_id: ${env.INFERENCE_MODEL}
|
||||||
provider_id: vllm-inference
|
provider_id: vllm-inference
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
shields:
|
||||||
|
- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
|
||||||
vector_dbs: []
|
vector_dbs: []
|
||||||
datasets: []
|
datasets: []
|
||||||
scoring_fns: []
|
scoring_fns: []
|
||||||
|
|
|
@ -18,12 +18,13 @@ embedding_model_id = (
|
||||||
).identifier
|
).identifier
|
||||||
embedding_dimension = em.metadata["embedding_dimension"]
|
embedding_dimension = em.metadata["embedding_dimension"]
|
||||||
|
|
||||||
_ = client.vector_dbs.register(
|
vector_db = client.vector_dbs.register(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_id,
|
||||||
embedding_model=embedding_model_id,
|
embedding_model=embedding_model_id,
|
||||||
embedding_dimension=embedding_dimension,
|
embedding_dimension=embedding_dimension,
|
||||||
provider_id="faiss",
|
provider_id="faiss",
|
||||||
)
|
)
|
||||||
|
vector_db_id = vector_db.identifier
|
||||||
source = "https://www.paulgraham.com/greatwork.html"
|
source = "https://www.paulgraham.com/greatwork.html"
|
||||||
print("rag_tool> Ingesting document:", source)
|
print("rag_tool> Ingesting document:", source)
|
||||||
document = RAGDocument(
|
document = RAGDocument(
|
||||||
|
@ -35,7 +36,7 @@ document = RAGDocument(
|
||||||
client.tool_runtime.rag_tool.insert(
|
client.tool_runtime.rag_tool.insert(
|
||||||
documents=[document],
|
documents=[document],
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_id,
|
||||||
chunk_size_in_tokens=50,
|
chunk_size_in_tokens=100,
|
||||||
)
|
)
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
client,
|
client,
|
||||||
|
|
|
@ -15,8 +15,8 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man
|
||||||
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
|
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
|
||||||
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
|
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
|
||||||
| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE |
|
| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE |
|
||||||
| `connect_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
|
| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
|
||||||
| `read_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
|
| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
|
||||||
| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
|
| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
|
||||||
|
|
||||||
## Sample Configuration
|
## Sample Configuration
|
||||||
|
|
|
@ -15,8 +15,8 @@ AWS Bedrock safety provider for content moderation using AWS's safety services.
|
||||||
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
|
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
|
||||||
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
|
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
|
||||||
| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE |
|
| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE |
|
||||||
| `connect_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
|
| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
|
||||||
| `read_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
|
| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
|
||||||
| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
|
| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
|
||||||
|
|
||||||
## Sample Configuration
|
## Sample Configuration
|
||||||
|
|
|
@ -527,7 +527,7 @@ class InferenceRouter(Inference):
|
||||||
|
|
||||||
# Store the response with the ID that will be returned to the client
|
# Store the response with the ID that will be returned to the client
|
||||||
if self.store:
|
if self.store:
|
||||||
await self.store.store_chat_completion(response, messages)
|
asyncio.create_task(self.store.store_chat_completion(response, messages))
|
||||||
|
|
||||||
if self.telemetry:
|
if self.telemetry:
|
||||||
metrics = self._construct_metrics(
|
metrics = self._construct_metrics(
|
||||||
|
@ -855,4 +855,4 @@ class InferenceRouter(Inference):
|
||||||
object="chat.completion",
|
object="chat.completion",
|
||||||
)
|
)
|
||||||
logger.debug(f"InferenceRouter.completion_response: {final_response}")
|
logger.debug(f"InferenceRouter.completion_response: {final_response}")
|
||||||
await self.store.store_chat_completion(final_response, messages)
|
asyncio.create_task(self.store.store_chat_completion(final_response, messages))
|
||||||
|
|
|
@ -52,7 +52,6 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||||
provider_vector_db_id: str | None = None,
|
provider_vector_db_id: str | None = None,
|
||||||
vector_db_name: str | None = None,
|
vector_db_name: str | None = None,
|
||||||
) -> VectorDB:
|
) -> VectorDB:
|
||||||
provider_vector_db_id = provider_vector_db_id or vector_db_id
|
|
||||||
if provider_id is None:
|
if provider_id is None:
|
||||||
if len(self.impls_by_provider_id) > 0:
|
if len(self.impls_by_provider_id) > 0:
|
||||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||||
|
@ -69,14 +68,33 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
||||||
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
|
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
|
||||||
if "embedding_dimension" not in model.metadata:
|
if "embedding_dimension" not in model.metadata:
|
||||||
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
|
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
|
||||||
|
|
||||||
|
provider = self.impls_by_provider_id[provider_id]
|
||||||
|
logger.warning(
|
||||||
|
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
|
||||||
|
)
|
||||||
|
vector_store = await provider.openai_create_vector_store(
|
||||||
|
name=vector_db_name or vector_db_id,
|
||||||
|
embedding_model=embedding_model,
|
||||||
|
embedding_dimension=model.metadata["embedding_dimension"],
|
||||||
|
provider_id=provider_id,
|
||||||
|
provider_vector_db_id=provider_vector_db_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
vector_store_id = vector_store.id
|
||||||
|
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
|
||||||
|
logger.warning(
|
||||||
|
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
|
||||||
|
)
|
||||||
|
|
||||||
vector_db_data = {
|
vector_db_data = {
|
||||||
"identifier": vector_db_id,
|
"identifier": vector_store_id,
|
||||||
"type": ResourceType.vector_db.value,
|
"type": ResourceType.vector_db.value,
|
||||||
"provider_id": provider_id,
|
"provider_id": provider_id,
|
||||||
"provider_resource_id": provider_vector_db_id,
|
"provider_resource_id": actual_provider_vector_db_id,
|
||||||
"embedding_model": embedding_model,
|
"embedding_model": embedding_model,
|
||||||
"embedding_dimension": model.metadata["embedding_dimension"],
|
"embedding_dimension": model.metadata["embedding_dimension"],
|
||||||
"vector_db_name": vector_db_name,
|
"vector_db_name": vector_store.name,
|
||||||
}
|
}
|
||||||
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
|
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
|
||||||
await self.register_object(vector_db)
|
await self.register_object(vector_db)
|
||||||
|
|
|
@ -132,15 +132,17 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
elif isinstance(exc, ConflictError):
|
elif isinstance(exc, ConflictError):
|
||||||
return HTTPException(status_code=409, detail=str(exc))
|
return HTTPException(status_code=httpx.codes.CONFLICT, detail=str(exc))
|
||||||
elif isinstance(exc, ResourceNotFoundError):
|
elif isinstance(exc, ResourceNotFoundError):
|
||||||
return HTTPException(status_code=404, detail=str(exc))
|
return HTTPException(status_code=httpx.codes.NOT_FOUND, detail=str(exc))
|
||||||
elif isinstance(exc, ValueError):
|
elif isinstance(exc, ValueError):
|
||||||
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
|
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
|
||||||
elif isinstance(exc, BadRequestError):
|
elif isinstance(exc, BadRequestError):
|
||||||
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc))
|
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc))
|
||||||
elif isinstance(exc, PermissionError | AccessDeniedError):
|
elif isinstance(exc, PermissionError | AccessDeniedError):
|
||||||
return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}")
|
return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}")
|
||||||
|
elif isinstance(exc, ConnectionError | httpx.ConnectError):
|
||||||
|
return HTTPException(status_code=httpx.codes.BAD_GATEWAY, detail=str(exc))
|
||||||
elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
|
elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
|
||||||
return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}")
|
return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}")
|
||||||
elif isinstance(exc, NotImplementedError):
|
elif isinstance(exc, NotImplementedError):
|
||||||
|
|
|
@ -11,9 +11,7 @@ from ..starter.starter import get_distribution_template as get_starter_distribut
|
||||||
|
|
||||||
|
|
||||||
def get_distribution_template() -> DistributionTemplate:
|
def get_distribution_template() -> DistributionTemplate:
|
||||||
template = get_starter_distribution_template()
|
template = get_starter_distribution_template(name="ci-tests")
|
||||||
name = "ci-tests"
|
|
||||||
template.name = name
|
|
||||||
template.description = "CI tests for Llama Stack"
|
template.description = "CI tests for Llama Stack"
|
||||||
|
|
||||||
return template
|
return template
|
||||||
|
|
|
@ -89,28 +89,28 @@ providers:
|
||||||
config:
|
config:
|
||||||
kvstore:
|
kvstore:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/faiss_store.db
|
||||||
- provider_id: sqlite-vec
|
- provider_id: sqlite-vec
|
||||||
provider_type: inline::sqlite-vec
|
provider_type: inline::sqlite-vec
|
||||||
config:
|
config:
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec.db
|
||||||
kvstore:
|
kvstore:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec_registry.db
|
||||||
- provider_id: ${env.MILVUS_URL:+milvus}
|
- provider_id: ${env.MILVUS_URL:+milvus}
|
||||||
provider_type: inline::milvus
|
provider_type: inline::milvus
|
||||||
config:
|
config:
|
||||||
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
|
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/ci-tests}/milvus.db
|
||||||
kvstore:
|
kvstore:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/milvus_registry.db
|
||||||
- provider_id: ${env.CHROMADB_URL:+chromadb}
|
- provider_id: ${env.CHROMADB_URL:+chromadb}
|
||||||
provider_type: remote::chromadb
|
provider_type: remote::chromadb
|
||||||
config:
|
config:
|
||||||
url: ${env.CHROMADB_URL:=}
|
url: ${env.CHROMADB_URL:=}
|
||||||
kvstore:
|
kvstore:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests/}/chroma_remote_registry.db
|
||||||
- provider_id: ${env.PGVECTOR_DB:+pgvector}
|
- provider_id: ${env.PGVECTOR_DB:+pgvector}
|
||||||
provider_type: remote::pgvector
|
provider_type: remote::pgvector
|
||||||
config:
|
config:
|
||||||
|
@ -121,15 +121,15 @@ providers:
|
||||||
password: ${env.PGVECTOR_PASSWORD:=}
|
password: ${env.PGVECTOR_PASSWORD:=}
|
||||||
kvstore:
|
kvstore:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/pgvector_registry.db
|
||||||
files:
|
files:
|
||||||
- provider_id: meta-reference-files
|
- provider_id: meta-reference-files
|
||||||
provider_type: inline::localfs
|
provider_type: inline::localfs
|
||||||
config:
|
config:
|
||||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
|
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/ci-tests/files}
|
||||||
metadata_store:
|
metadata_store:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/files_metadata.db
|
||||||
safety:
|
safety:
|
||||||
- provider_id: llama-guard
|
- provider_id: llama-guard
|
||||||
provider_type: inline::llama-guard
|
provider_type: inline::llama-guard
|
||||||
|
|
|
@ -89,28 +89,28 @@ providers:
|
||||||
config:
|
config:
|
||||||
kvstore:
|
kvstore:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/faiss_store.db
|
||||||
- provider_id: sqlite-vec
|
- provider_id: sqlite-vec
|
||||||
provider_type: inline::sqlite-vec
|
provider_type: inline::sqlite-vec
|
||||||
config:
|
config:
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/sqlite_vec.db
|
||||||
kvstore:
|
kvstore:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/sqlite_vec_registry.db
|
||||||
- provider_id: ${env.MILVUS_URL:+milvus}
|
- provider_id: ${env.MILVUS_URL:+milvus}
|
||||||
provider_type: inline::milvus
|
provider_type: inline::milvus
|
||||||
config:
|
config:
|
||||||
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db
|
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter-gpu}/milvus.db
|
||||||
kvstore:
|
kvstore:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/milvus_registry.db
|
||||||
- provider_id: ${env.CHROMADB_URL:+chromadb}
|
- provider_id: ${env.CHROMADB_URL:+chromadb}
|
||||||
provider_type: remote::chromadb
|
provider_type: remote::chromadb
|
||||||
config:
|
config:
|
||||||
url: ${env.CHROMADB_URL:=}
|
url: ${env.CHROMADB_URL:=}
|
||||||
kvstore:
|
kvstore:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu/}/chroma_remote_registry.db
|
||||||
- provider_id: ${env.PGVECTOR_DB:+pgvector}
|
- provider_id: ${env.PGVECTOR_DB:+pgvector}
|
||||||
provider_type: remote::pgvector
|
provider_type: remote::pgvector
|
||||||
config:
|
config:
|
||||||
|
@ -121,15 +121,15 @@ providers:
|
||||||
password: ${env.PGVECTOR_PASSWORD:=}
|
password: ${env.PGVECTOR_PASSWORD:=}
|
||||||
kvstore:
|
kvstore:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/pgvector_registry.db
|
||||||
files:
|
files:
|
||||||
- provider_id: meta-reference-files
|
- provider_id: meta-reference-files
|
||||||
provider_type: inline::localfs
|
provider_type: inline::localfs
|
||||||
config:
|
config:
|
||||||
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
|
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter-gpu/files}
|
||||||
metadata_store:
|
metadata_store:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
|
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/files_metadata.db
|
||||||
safety:
|
safety:
|
||||||
- provider_id: llama-guard
|
- provider_id: llama-guard
|
||||||
provider_type: inline::llama-guard
|
provider_type: inline::llama-guard
|
||||||
|
|
|
@ -11,9 +11,7 @@ from ..starter.starter import get_distribution_template as get_starter_distribut
|
||||||
|
|
||||||
|
|
||||||
def get_distribution_template() -> DistributionTemplate:
|
def get_distribution_template() -> DistributionTemplate:
|
||||||
template = get_starter_distribution_template()
|
template = get_starter_distribution_template(name="starter-gpu")
|
||||||
name = "starter-gpu"
|
|
||||||
template.name = name
|
|
||||||
template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments."
|
template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments."
|
||||||
|
|
||||||
template.providers["post_training"] = [
|
template.providers["post_training"] = [
|
||||||
|
|
|
@ -99,9 +99,8 @@ def get_remote_inference_providers() -> list[Provider]:
|
||||||
return inference_providers
|
return inference_providers
|
||||||
|
|
||||||
|
|
||||||
def get_distribution_template() -> DistributionTemplate:
|
def get_distribution_template(name: str = "starter") -> DistributionTemplate:
|
||||||
remote_inference_providers = get_remote_inference_providers()
|
remote_inference_providers = get_remote_inference_providers()
|
||||||
name = "starter"
|
|
||||||
|
|
||||||
providers = {
|
providers = {
|
||||||
"inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers]
|
"inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers]
|
||||||
|
|
|
@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches):
|
||||||
|
|
||||||
# TODO: set expiration time for garbage collection
|
# TODO: set expiration time for garbage collection
|
||||||
|
|
||||||
if endpoint not in ["/v1/chat/completions"]:
|
if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint",
|
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
|
||||||
)
|
)
|
||||||
|
|
||||||
if completion_window != "24h":
|
if completion_window != "24h":
|
||||||
|
@ -424,13 +424,21 @@ class ReferenceBatchesImpl(Batches):
|
||||||
)
|
)
|
||||||
valid = False
|
valid = False
|
||||||
|
|
||||||
for param, expected_type, type_string in [
|
if batch.endpoint == "/v1/chat/completions":
|
||||||
("model", str, "a string"),
|
required_params = [
|
||||||
# messages is specific to /v1/chat/completions
|
("model", str, "a string"),
|
||||||
# we could skip validating messages here and let inference fail. however,
|
# messages is specific to /v1/chat/completions
|
||||||
# that would be a very expensive way to find out messages is wrong.
|
# we could skip validating messages here and let inference fail. however,
|
||||||
("messages", list, "an array"), # TODO: allow messages to be a string?
|
# that would be a very expensive way to find out messages is wrong.
|
||||||
]:
|
("messages", list, "an array"), # TODO: allow messages to be a string?
|
||||||
|
]
|
||||||
|
else: # /v1/completions
|
||||||
|
required_params = [
|
||||||
|
("model", str, "a string"),
|
||||||
|
("prompt", str, "a string"), # TODO: allow prompt to be a list of strings??
|
||||||
|
]
|
||||||
|
|
||||||
|
for param, expected_type, type_string in required_params:
|
||||||
if param not in body:
|
if param not in body:
|
||||||
errors.append(
|
errors.append(
|
||||||
BatchError(
|
BatchError(
|
||||||
|
@ -591,20 +599,37 @@ class ReferenceBatchesImpl(Batches):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# TODO(SECURITY): review body for security issues
|
# TODO(SECURITY): review body for security issues
|
||||||
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
|
if request.url == "/v1/chat/completions":
|
||||||
chat_response = await self.inference_api.openai_chat_completion(**request.body)
|
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
|
||||||
|
chat_response = await self.inference_api.openai_chat_completion(**request.body)
|
||||||
|
|
||||||
# this is for mypy, we don't allow streaming so we'll get the right type
|
# this is for mypy, we don't allow streaming so we'll get the right type
|
||||||
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
|
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
|
||||||
return {
|
return {
|
||||||
"id": request_id,
|
"id": request_id,
|
||||||
"custom_id": request.custom_id,
|
"custom_id": request.custom_id,
|
||||||
"response": {
|
"response": {
|
||||||
"status_code": 200,
|
"status_code": 200,
|
||||||
"request_id": request_id, # TODO: should this be different?
|
"request_id": request_id, # TODO: should this be different?
|
||||||
"body": chat_response.model_dump_json(),
|
"body": chat_response.model_dump_json(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
else: # /v1/completions
|
||||||
|
completion_response = await self.inference_api.openai_completion(**request.body)
|
||||||
|
|
||||||
|
# this is for mypy, we don't allow streaming so we'll get the right type
|
||||||
|
assert hasattr(completion_response, "model_dump_json"), (
|
||||||
|
"Completion response must have model_dump_json method"
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"id": request_id,
|
||||||
|
"custom_id": request.custom_id,
|
||||||
|
"response": {
|
||||||
|
"status_code": 200,
|
||||||
|
"request_id": request_id,
|
||||||
|
"body": completion_response.model_dump_json(),
|
||||||
|
},
|
||||||
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
|
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
|
||||||
return {
|
return {
|
||||||
|
|
|
@ -14,6 +14,6 @@ from .config import RagToolRuntimeConfig
|
||||||
async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
|
async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
|
||||||
from .memory import MemoryToolRuntimeImpl
|
from .memory import MemoryToolRuntimeImpl
|
||||||
|
|
||||||
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference])
|
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@ -5,10 +5,15 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import io
|
||||||
|
import mimetypes
|
||||||
import secrets
|
import secrets
|
||||||
import string
|
import string
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import UploadFile
|
||||||
from pydantic import TypeAdapter
|
from pydantic import TypeAdapter
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
|
@ -17,6 +22,7 @@ from llama_stack.apis.common.content_types import (
|
||||||
InterleavedContentItem,
|
InterleavedContentItem,
|
||||||
TextContentItem,
|
TextContentItem,
|
||||||
)
|
)
|
||||||
|
from llama_stack.apis.files import Files, OpenAIFilePurpose
|
||||||
from llama_stack.apis.inference import Inference
|
from llama_stack.apis.inference import Inference
|
||||||
from llama_stack.apis.tools import (
|
from llama_stack.apis.tools import (
|
||||||
ListToolDefsResponse,
|
ListToolDefsResponse,
|
||||||
|
@ -30,13 +36,18 @@ from llama_stack.apis.tools import (
|
||||||
ToolParameter,
|
ToolParameter,
|
||||||
ToolRuntime,
|
ToolRuntime,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
from llama_stack.apis.vector_io import (
|
||||||
|
QueryChunksResponse,
|
||||||
|
VectorIO,
|
||||||
|
VectorStoreChunkingStrategyStatic,
|
||||||
|
VectorStoreChunkingStrategyStaticConfig,
|
||||||
|
)
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
content_from_doc,
|
content_from_doc,
|
||||||
make_overlapped_chunks,
|
parse_data_url,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .config import RagToolRuntimeConfig
|
from .config import RagToolRuntimeConfig
|
||||||
|
@ -55,10 +66,12 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
||||||
config: RagToolRuntimeConfig,
|
config: RagToolRuntimeConfig,
|
||||||
vector_io_api: VectorIO,
|
vector_io_api: VectorIO,
|
||||||
inference_api: Inference,
|
inference_api: Inference,
|
||||||
|
files_api: Files,
|
||||||
):
|
):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.vector_io_api = vector_io_api
|
self.vector_io_api = vector_io_api
|
||||||
self.inference_api = inference_api
|
self.inference_api = inference_api
|
||||||
|
self.files_api = files_api
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
@ -78,27 +91,50 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
|
||||||
vector_db_id: str,
|
vector_db_id: str,
|
||||||
chunk_size_in_tokens: int = 512,
|
chunk_size_in_tokens: int = 512,
|
||||||
) -> None:
|
) -> None:
|
||||||
chunks = []
|
if not documents:
|
||||||
|
return
|
||||||
|
|
||||||
for doc in documents:
|
for doc in documents:
|
||||||
content = await content_from_doc(doc)
|
if isinstance(doc.content, URL):
|
||||||
# TODO: we should add enrichment here as URLs won't be added to the metadata by default
|
if doc.content.uri.startswith("data:"):
|
||||||
chunks.extend(
|
parts = parse_data_url(doc.content.uri)
|
||||||
make_overlapped_chunks(
|
file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode()
|
||||||
doc.document_id,
|
mime_type = parts["mimetype"]
|
||||||
content,
|
else:
|
||||||
chunk_size_in_tokens,
|
async with httpx.AsyncClient() as client:
|
||||||
chunk_size_in_tokens // 4,
|
response = await client.get(doc.content.uri)
|
||||||
doc.metadata,
|
file_data = response.content
|
||||||
|
mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream")
|
||||||
|
else:
|
||||||
|
content_str = await content_from_doc(doc)
|
||||||
|
file_data = content_str.encode("utf-8")
|
||||||
|
mime_type = doc.mime_type or "text/plain"
|
||||||
|
|
||||||
|
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
|
||||||
|
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
|
||||||
|
|
||||||
|
file_obj = io.BytesIO(file_data)
|
||||||
|
file_obj.name = filename
|
||||||
|
|
||||||
|
upload_file = UploadFile(file=file_obj, filename=filename)
|
||||||
|
|
||||||
|
created_file = await self.files_api.openai_upload_file(
|
||||||
|
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
|
||||||
|
)
|
||||||
|
|
||||||
|
chunking_strategy = VectorStoreChunkingStrategyStatic(
|
||||||
|
static=VectorStoreChunkingStrategyStaticConfig(
|
||||||
|
max_chunk_size_tokens=chunk_size_in_tokens,
|
||||||
|
chunk_overlap_tokens=chunk_size_in_tokens // 4,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not chunks:
|
await self.vector_io_api.openai_attach_file_to_vector_store(
|
||||||
return
|
vector_store_id=vector_db_id,
|
||||||
|
file_id=created_file.id,
|
||||||
await self.vector_io_api.insert_chunks(
|
attributes=doc.metadata,
|
||||||
chunks=chunks,
|
chunking_strategy=chunking_strategy,
|
||||||
vector_db_id=vector_db_id,
|
)
|
||||||
)
|
|
||||||
|
|
||||||
async def query(
|
async def query(
|
||||||
self,
|
self,
|
||||||
|
|
|
@ -116,7 +116,7 @@ def available_providers() -> list[ProviderSpec]:
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_type="fireworks",
|
adapter_type="fireworks",
|
||||||
pip_packages=[
|
pip_packages=[
|
||||||
"fireworks-ai<=0.18.0",
|
"fireworks-ai<=0.17.16",
|
||||||
],
|
],
|
||||||
module="llama_stack.providers.remote.inference.fireworks",
|
module="llama_stack.providers.remote.inference.fireworks",
|
||||||
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
|
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
|
||||||
|
@ -207,7 +207,7 @@ def available_providers() -> list[ProviderSpec]:
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_type="gemini",
|
adapter_type="gemini",
|
||||||
pip_packages=["litellm"],
|
pip_packages=["litellm", "openai"],
|
||||||
module="llama_stack.providers.remote.inference.gemini",
|
module="llama_stack.providers.remote.inference.gemini",
|
||||||
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
|
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
|
||||||
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
|
||||||
|
@ -270,7 +270,7 @@ Available Models:
|
||||||
api=Api.inference,
|
api=Api.inference,
|
||||||
adapter=AdapterSpec(
|
adapter=AdapterSpec(
|
||||||
adapter_type="sambanova",
|
adapter_type="sambanova",
|
||||||
pip_packages=["litellm"],
|
pip_packages=["litellm", "openai"],
|
||||||
module="llama_stack.providers.remote.inference.sambanova",
|
module="llama_stack.providers.remote.inference.sambanova",
|
||||||
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
|
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
|
||||||
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
|
||||||
|
|
|
@ -32,7 +32,7 @@ def available_providers() -> list[ProviderSpec]:
|
||||||
],
|
],
|
||||||
module="llama_stack.providers.inline.tool_runtime.rag",
|
module="llama_stack.providers.inline.tool_runtime.rag",
|
||||||
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
|
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
|
||||||
api_dependencies=[Api.vector_io, Api.inference],
|
api_dependencies=[Api.vector_io, Api.inference, Api.files],
|
||||||
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
|
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
|
||||||
),
|
),
|
||||||
remote_provider_spec(
|
remote_provider_spec(
|
||||||
|
|
|
@ -5,12 +5,13 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||||
|
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||||
|
|
||||||
from .config import GeminiConfig
|
from .config import GeminiConfig
|
||||||
from .models import MODEL_ENTRIES
|
from .models import MODEL_ENTRIES
|
||||||
|
|
||||||
|
|
||||||
class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
|
class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||||
def __init__(self, config: GeminiConfig) -> None:
|
def __init__(self, config: GeminiConfig) -> None:
|
||||||
LiteLLMOpenAIMixin.__init__(
|
LiteLLMOpenAIMixin.__init__(
|
||||||
self,
|
self,
|
||||||
|
@ -21,6 +22,11 @@ class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
|
||||||
)
|
)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
|
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||||
|
|
||||||
|
def get_base_url(self):
|
||||||
|
return "https://generativelanguage.googleapis.com/v1beta/openai/"
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
await super().initialize()
|
await super().initialize()
|
||||||
|
|
||||||
|
|
|
@ -4,13 +4,26 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||||
|
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||||
|
|
||||||
from .config import SambaNovaImplConfig
|
from .config import SambaNovaImplConfig
|
||||||
from .models import MODEL_ENTRIES
|
from .models import MODEL_ENTRIES
|
||||||
|
|
||||||
|
|
||||||
class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
|
class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
|
||||||
|
"""
|
||||||
|
SambaNova Inference Adapter for Llama Stack.
|
||||||
|
|
||||||
|
Note: The inheritance order is important here. OpenAIMixin must come before
|
||||||
|
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
|
||||||
|
is used instead of LiteLLMOpenAIMixin.check_model_availability().
|
||||||
|
|
||||||
|
- OpenAIMixin.check_model_availability() queries the /v1/models to check if a model exists
|
||||||
|
- LiteLLMOpenAIMixin.check_model_availability() checks the static registry within LiteLLM
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, config: SambaNovaImplConfig):
|
def __init__(self, config: SambaNovaImplConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.environment_available_models = []
|
self.environment_available_models = []
|
||||||
|
@ -24,3 +37,14 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
|
||||||
download_images=True, # SambaNova requires base64 image encoding
|
download_images=True, # SambaNova requires base64 image encoding
|
||||||
json_schema_strict=False, # SambaNova doesn't support strict=True yet
|
json_schema_strict=False, # SambaNova doesn't support strict=True yet
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
|
||||||
|
get_api_key = LiteLLMOpenAIMixin.get_api_key
|
||||||
|
|
||||||
|
def get_base_url(self) -> str:
|
||||||
|
"""
|
||||||
|
Get the base URL for OpenAI mixin.
|
||||||
|
|
||||||
|
:return: The SambaNova base URL
|
||||||
|
"""
|
||||||
|
return self.config.url
|
||||||
|
|
|
@ -4,53 +4,55 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class BedrockBaseConfig(BaseModel):
|
class BedrockBaseConfig(BaseModel):
|
||||||
aws_access_key_id: str | None = Field(
|
aws_access_key_id: str | None = Field(
|
||||||
default=None,
|
default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
|
||||||
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
|
||||||
)
|
)
|
||||||
aws_secret_access_key: str | None = Field(
|
aws_secret_access_key: str | None = Field(
|
||||||
default=None,
|
default_factory=lambda: os.getenv("AWS_SECRET_ACCESS_KEY"),
|
||||||
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
|
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
|
||||||
)
|
)
|
||||||
aws_session_token: str | None = Field(
|
aws_session_token: str | None = Field(
|
||||||
default=None,
|
default_factory=lambda: os.getenv("AWS_SESSION_TOKEN"),
|
||||||
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
|
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
|
||||||
)
|
)
|
||||||
region_name: str | None = Field(
|
region_name: str | None = Field(
|
||||||
default=None,
|
default_factory=lambda: os.getenv("AWS_DEFAULT_REGION"),
|
||||||
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
|
description="The default AWS Region to use, for example, us-west-1 or us-west-2."
|
||||||
"Default use environment variable: AWS_DEFAULT_REGION",
|
"Default use environment variable: AWS_DEFAULT_REGION",
|
||||||
)
|
)
|
||||||
profile_name: str | None = Field(
|
profile_name: str | None = Field(
|
||||||
default=None,
|
default_factory=lambda: os.getenv("AWS_PROFILE"),
|
||||||
description="The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE",
|
description="The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE",
|
||||||
)
|
)
|
||||||
total_max_attempts: int | None = Field(
|
total_max_attempts: int | None = Field(
|
||||||
default=None,
|
default_factory=lambda: int(val) if (val := os.getenv("AWS_MAX_ATTEMPTS")) else None,
|
||||||
description="An integer representing the maximum number of attempts that will be made for a single request, "
|
description="An integer representing the maximum number of attempts that will be made for a single request, "
|
||||||
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
|
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
|
||||||
)
|
)
|
||||||
retry_mode: str | None = Field(
|
retry_mode: str | None = Field(
|
||||||
default=None,
|
default_factory=lambda: os.getenv("AWS_RETRY_MODE"),
|
||||||
description="A string representing the type of retries Boto3 will perform."
|
description="A string representing the type of retries Boto3 will perform."
|
||||||
"Default use environment variable: AWS_RETRY_MODE",
|
"Default use environment variable: AWS_RETRY_MODE",
|
||||||
)
|
)
|
||||||
connect_timeout: float | None = Field(
|
connect_timeout: float | None = Field(
|
||||||
default=60,
|
default_factory=lambda: float(os.getenv("AWS_CONNECT_TIMEOUT", "60")),
|
||||||
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
|
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
|
||||||
"The default is 60 seconds.",
|
"The default is 60 seconds.",
|
||||||
)
|
)
|
||||||
read_timeout: float | None = Field(
|
read_timeout: float | None = Field(
|
||||||
default=60,
|
default_factory=lambda: float(os.getenv("AWS_READ_TIMEOUT", "60")),
|
||||||
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
|
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
|
||||||
"The default is 60 seconds.",
|
"The default is 60 seconds.",
|
||||||
)
|
)
|
||||||
session_ttl: int | None = Field(
|
session_ttl: int | None = Field(
|
||||||
default=3600,
|
default_factory=lambda: int(os.getenv("AWS_SESSION_TTL", "3600")),
|
||||||
description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
|
description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import struct
|
import struct
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
@ -43,9 +44,11 @@ class SentenceTransformerEmbeddingMixin:
|
||||||
task_type: EmbeddingTaskType | None = None,
|
task_type: EmbeddingTaskType | None = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
model = await self.model_store.get_model(model_id)
|
model = await self.model_store.get_model(model_id)
|
||||||
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id)
|
embedding_model = await self._load_sentence_transformer_model(model.provider_resource_id)
|
||||||
embeddings = embedding_model.encode(
|
embeddings = await asyncio.to_thread(
|
||||||
[interleaved_content_as_str(content) for content in contents], show_progress_bar=False
|
embedding_model.encode,
|
||||||
|
[interleaved_content_as_str(content) for content in contents],
|
||||||
|
show_progress_bar=False,
|
||||||
)
|
)
|
||||||
return EmbeddingsResponse(embeddings=embeddings)
|
return EmbeddingsResponse(embeddings=embeddings)
|
||||||
|
|
||||||
|
@ -64,8 +67,8 @@ class SentenceTransformerEmbeddingMixin:
|
||||||
|
|
||||||
# Get the model and generate embeddings
|
# Get the model and generate embeddings
|
||||||
model_obj = await self.model_store.get_model(model)
|
model_obj = await self.model_store.get_model(model)
|
||||||
embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id)
|
embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id)
|
||||||
embeddings = embedding_model.encode(input_list, show_progress_bar=False)
|
embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False)
|
||||||
|
|
||||||
# Convert embeddings to the requested format
|
# Convert embeddings to the requested format
|
||||||
data = []
|
data = []
|
||||||
|
@ -93,7 +96,7 @@ class SentenceTransformerEmbeddingMixin:
|
||||||
usage=usage,
|
usage=usage,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
|
async def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
|
||||||
global EMBEDDING_MODELS
|
global EMBEDDING_MODELS
|
||||||
|
|
||||||
loaded_model = EMBEDDING_MODELS.get(model)
|
loaded_model = EMBEDDING_MODELS.get(model)
|
||||||
|
@ -101,8 +104,12 @@ class SentenceTransformerEmbeddingMixin:
|
||||||
return loaded_model
|
return loaded_model
|
||||||
|
|
||||||
log.info(f"Loading sentence transformer for {model}...")
|
log.info(f"Loading sentence transformer for {model}...")
|
||||||
from sentence_transformers import SentenceTransformer
|
|
||||||
|
|
||||||
loaded_model = SentenceTransformer(model)
|
def _load_model():
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
|
||||||
|
return SentenceTransformer(model)
|
||||||
|
|
||||||
|
loaded_model = await asyncio.to_thread(_load_model)
|
||||||
EMBEDDING_MODELS[model] = loaded_model
|
EMBEDDING_MODELS[model] = loaded_model
|
||||||
return loaded_model
|
return loaded_model
|
||||||
|
|
|
@ -67,6 +67,38 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat
|
||||||
raise AuthenticationRequiredError(exc) from exc
|
raise AuthenticationRequiredError(exc) from exc
|
||||||
if i == len(connection_strategies) - 1:
|
if i == len(connection_strategies) - 1:
|
||||||
raise
|
raise
|
||||||
|
except* httpx.ConnectError as eg:
|
||||||
|
# Connection refused, server down, network unreachable
|
||||||
|
if i == len(connection_strategies) - 1:
|
||||||
|
error_msg = f"Failed to connect to MCP server at {endpoint}: Connection refused"
|
||||||
|
logger.error(f"MCP connection error: {error_msg}")
|
||||||
|
raise ConnectionError(error_msg) from eg
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"failed to connect to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
|
||||||
|
)
|
||||||
|
except* httpx.TimeoutException as eg:
|
||||||
|
# Request timeout, server too slow
|
||||||
|
if i == len(connection_strategies) - 1:
|
||||||
|
error_msg = f"MCP server at {endpoint} timed out"
|
||||||
|
logger.error(f"MCP timeout error: {error_msg}")
|
||||||
|
raise TimeoutError(error_msg) from eg
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"MCP server at {endpoint} timed out via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
|
||||||
|
)
|
||||||
|
except* httpx.RequestError as eg:
|
||||||
|
# DNS resolution failures, network errors, invalid URLs
|
||||||
|
if i == len(connection_strategies) - 1:
|
||||||
|
# Get the first exception's message for the error string
|
||||||
|
exc_msg = str(eg.exceptions[0]) if eg.exceptions else "Unknown error"
|
||||||
|
error_msg = f"Network error connecting to MCP server at {endpoint}: {exc_msg}"
|
||||||
|
logger.error(f"MCP network error: {error_msg}")
|
||||||
|
raise ConnectionError(error_msg) from eg
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
f"network error connecting to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
|
||||||
|
)
|
||||||
except* McpError:
|
except* McpError:
|
||||||
if i < len(connection_strategies) - 1:
|
if i < len(connection_strategies) - 1:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
|
|
|
@ -15,7 +15,7 @@ set -euo pipefail
|
||||||
BRANCH=""
|
BRANCH=""
|
||||||
TEST_SUBDIRS=""
|
TEST_SUBDIRS=""
|
||||||
TEST_PROVIDER="ollama"
|
TEST_PROVIDER="ollama"
|
||||||
RUN_VISION_TESTS=false
|
TEST_SUITE="base"
|
||||||
TEST_PATTERN=""
|
TEST_PATTERN=""
|
||||||
|
|
||||||
# Help function
|
# Help function
|
||||||
|
@ -27,9 +27,9 @@ Trigger the integration test recording workflow remotely. This way you do not ne
|
||||||
|
|
||||||
OPTIONS:
|
OPTIONS:
|
||||||
-b, --branch BRANCH Branch to run the workflow on (defaults to current branch)
|
-b, --branch BRANCH Branch to run the workflow on (defaults to current branch)
|
||||||
-s, --test-subdirs DIRS Comma-separated list of test subdirectories to run (REQUIRED)
|
|
||||||
-p, --test-provider PROVIDER Test provider to use: vllm or ollama (default: ollama)
|
-p, --test-provider PROVIDER Test provider to use: vllm or ollama (default: ollama)
|
||||||
-v, --run-vision-tests Include vision tests in the recording
|
-t, --test-suite SUITE Test suite to use: base, responses, vision, etc. (default: base)
|
||||||
|
-s, --test-subdirs DIRS Comma-separated list of test subdirectories to run (overrides suite)
|
||||||
-k, --test-pattern PATTERN Regex pattern to pass to pytest -k
|
-k, --test-pattern PATTERN Regex pattern to pass to pytest -k
|
||||||
-h, --help Show this help message
|
-h, --help Show this help message
|
||||||
|
|
||||||
|
@ -38,7 +38,7 @@ EXAMPLES:
|
||||||
$0 --test-subdirs "agents"
|
$0 --test-subdirs "agents"
|
||||||
|
|
||||||
# Record tests for specific branch with vision tests
|
# Record tests for specific branch with vision tests
|
||||||
$0 -b my-feature-branch --test-subdirs "inference" --run-vision-tests
|
$0 -b my-feature-branch --test-suite vision
|
||||||
|
|
||||||
# Record multiple test subdirectories with specific provider
|
# Record multiple test subdirectories with specific provider
|
||||||
$0 --test-subdirs "agents,inference" --test-provider vllm
|
$0 --test-subdirs "agents,inference" --test-provider vllm
|
||||||
|
@ -71,9 +71,9 @@ while [[ $# -gt 0 ]]; do
|
||||||
TEST_PROVIDER="$2"
|
TEST_PROVIDER="$2"
|
||||||
shift 2
|
shift 2
|
||||||
;;
|
;;
|
||||||
-v|--run-vision-tests)
|
-t|--test-suite)
|
||||||
RUN_VISION_TESTS=true
|
TEST_SUITE="$2"
|
||||||
shift
|
shift 2
|
||||||
;;
|
;;
|
||||||
-k|--test-pattern)
|
-k|--test-pattern)
|
||||||
TEST_PATTERN="$2"
|
TEST_PATTERN="$2"
|
||||||
|
@ -92,11 +92,11 @@ while [[ $# -gt 0 ]]; do
|
||||||
done
|
done
|
||||||
|
|
||||||
# Validate required parameters
|
# Validate required parameters
|
||||||
if [[ -z "$TEST_SUBDIRS" ]]; then
|
if [[ -z "$TEST_SUBDIRS" && -z "$TEST_SUITE" ]]; then
|
||||||
echo "Error: --test-subdirs is required"
|
echo "Error: --test-subdirs or --test-suite is required"
|
||||||
echo "Please specify which test subdirectories to run, e.g.:"
|
echo "Please specify which test subdirectories to run or test suite to use, e.g.:"
|
||||||
echo " $0 --test-subdirs \"agents,inference\""
|
echo " $0 --test-subdirs \"agents,inference\""
|
||||||
echo " $0 --test-subdirs \"inference\" --run-vision-tests"
|
echo " $0 --test-suite vision"
|
||||||
echo ""
|
echo ""
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
@ -239,17 +239,19 @@ echo "Triggering integration test recording workflow..."
|
||||||
echo "Branch: $BRANCH"
|
echo "Branch: $BRANCH"
|
||||||
echo "Test provider: $TEST_PROVIDER"
|
echo "Test provider: $TEST_PROVIDER"
|
||||||
echo "Test subdirs: $TEST_SUBDIRS"
|
echo "Test subdirs: $TEST_SUBDIRS"
|
||||||
echo "Run vision tests: $RUN_VISION_TESTS"
|
echo "Test suite: $TEST_SUITE"
|
||||||
echo "Test pattern: ${TEST_PATTERN:-"(none)"}"
|
echo "Test pattern: ${TEST_PATTERN:-"(none)"}"
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Prepare inputs for gh workflow run
|
# Prepare inputs for gh workflow run
|
||||||
INPUTS="-f test-subdirs='$TEST_SUBDIRS'"
|
if [[ -n "$TEST_SUBDIRS" ]]; then
|
||||||
|
INPUTS="-f test-subdirs='$TEST_SUBDIRS'"
|
||||||
|
fi
|
||||||
if [[ -n "$TEST_PROVIDER" ]]; then
|
if [[ -n "$TEST_PROVIDER" ]]; then
|
||||||
INPUTS="$INPUTS -f test-provider='$TEST_PROVIDER'"
|
INPUTS="$INPUTS -f test-provider='$TEST_PROVIDER'"
|
||||||
fi
|
fi
|
||||||
if [[ "$RUN_VISION_TESTS" == "true" ]]; then
|
if [[ -n "$TEST_SUITE" ]]; then
|
||||||
INPUTS="$INPUTS -f run-vision-tests=true"
|
INPUTS="$INPUTS -f test-suite='$TEST_SUITE'"
|
||||||
fi
|
fi
|
||||||
if [[ -n "$TEST_PATTERN" ]]; then
|
if [[ -n "$TEST_PATTERN" ]]; then
|
||||||
INPUTS="$INPUTS -f test-pattern='$TEST_PATTERN'"
|
INPUTS="$INPUTS -f test-pattern='$TEST_PATTERN'"
|
||||||
|
|
|
@ -16,7 +16,7 @@ STACK_CONFIG=""
|
||||||
PROVIDER=""
|
PROVIDER=""
|
||||||
TEST_SUBDIRS=""
|
TEST_SUBDIRS=""
|
||||||
TEST_PATTERN=""
|
TEST_PATTERN=""
|
||||||
RUN_VISION_TESTS="false"
|
TEST_SUITE="base"
|
||||||
INFERENCE_MODE="replay"
|
INFERENCE_MODE="replay"
|
||||||
EXTRA_PARAMS=""
|
EXTRA_PARAMS=""
|
||||||
|
|
||||||
|
@ -28,12 +28,16 @@ Usage: $0 [OPTIONS]
|
||||||
Options:
|
Options:
|
||||||
--stack-config STRING Stack configuration to use (required)
|
--stack-config STRING Stack configuration to use (required)
|
||||||
--provider STRING Provider to use (ollama, vllm, etc.) (required)
|
--provider STRING Provider to use (ollama, vllm, etc.) (required)
|
||||||
--test-subdirs STRING Comma-separated list of test subdirectories to run (default: 'inference')
|
--test-suite STRING Comma-separated list of test suites to run (default: 'base')
|
||||||
--run-vision-tests Run vision tests instead of regular tests
|
|
||||||
--inference-mode STRING Inference mode: record or replay (default: replay)
|
--inference-mode STRING Inference mode: record or replay (default: replay)
|
||||||
|
--test-subdirs STRING Comma-separated list of test subdirectories to run (overrides suite)
|
||||||
--test-pattern STRING Regex pattern to pass to pytest -k
|
--test-pattern STRING Regex pattern to pass to pytest -k
|
||||||
--help Show this help message
|
--help Show this help message
|
||||||
|
|
||||||
|
Suites are defined in tests/integration/suites.py. They are used to narrow the collection of tests and provide default model options.
|
||||||
|
|
||||||
|
You can also specify subdirectories (of tests/integration) to select tests from, which will override the suite.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
# Basic inference tests with ollama
|
# Basic inference tests with ollama
|
||||||
$0 --stack-config server:ci-tests --provider ollama
|
$0 --stack-config server:ci-tests --provider ollama
|
||||||
|
@ -42,7 +46,7 @@ Examples:
|
||||||
$0 --stack-config server:ci-tests --provider vllm --test-subdirs 'inference,agents'
|
$0 --stack-config server:ci-tests --provider vllm --test-subdirs 'inference,agents'
|
||||||
|
|
||||||
# Vision tests with ollama
|
# Vision tests with ollama
|
||||||
$0 --stack-config server:ci-tests --provider ollama --run-vision-tests
|
$0 --stack-config server:ci-tests --provider ollama --test-suite vision
|
||||||
|
|
||||||
# Record mode for updating test recordings
|
# Record mode for updating test recordings
|
||||||
$0 --stack-config server:ci-tests --provider ollama --inference-mode record
|
$0 --stack-config server:ci-tests --provider ollama --inference-mode record
|
||||||
|
@ -64,9 +68,9 @@ while [[ $# -gt 0 ]]; do
|
||||||
TEST_SUBDIRS="$2"
|
TEST_SUBDIRS="$2"
|
||||||
shift 2
|
shift 2
|
||||||
;;
|
;;
|
||||||
--run-vision-tests)
|
--test-suite)
|
||||||
RUN_VISION_TESTS="true"
|
TEST_SUITE="$2"
|
||||||
shift
|
shift 2
|
||||||
;;
|
;;
|
||||||
--inference-mode)
|
--inference-mode)
|
||||||
INFERENCE_MODE="$2"
|
INFERENCE_MODE="$2"
|
||||||
|
@ -92,22 +96,25 @@ done
|
||||||
# Validate required parameters
|
# Validate required parameters
|
||||||
if [[ -z "$STACK_CONFIG" ]]; then
|
if [[ -z "$STACK_CONFIG" ]]; then
|
||||||
echo "Error: --stack-config is required"
|
echo "Error: --stack-config is required"
|
||||||
usage
|
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [[ -z "$PROVIDER" ]]; then
|
if [[ -z "$PROVIDER" ]]; then
|
||||||
echo "Error: --provider is required"
|
echo "Error: --provider is required"
|
||||||
usage
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ -z "$TEST_SUITE" && -z "$TEST_SUBDIRS" ]]; then
|
||||||
|
echo "Error: --test-suite or --test-subdirs is required"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo "=== Llama Stack Integration Test Runner ==="
|
echo "=== Llama Stack Integration Test Runner ==="
|
||||||
echo "Stack Config: $STACK_CONFIG"
|
echo "Stack Config: $STACK_CONFIG"
|
||||||
echo "Provider: $PROVIDER"
|
echo "Provider: $PROVIDER"
|
||||||
echo "Test Subdirs: $TEST_SUBDIRS"
|
|
||||||
echo "Vision Tests: $RUN_VISION_TESTS"
|
|
||||||
echo "Inference Mode: $INFERENCE_MODE"
|
echo "Inference Mode: $INFERENCE_MODE"
|
||||||
|
echo "Test Suite: $TEST_SUITE"
|
||||||
|
echo "Test Subdirs: $TEST_SUBDIRS"
|
||||||
echo "Test Pattern: $TEST_PATTERN"
|
echo "Test Pattern: $TEST_PATTERN"
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
|
@ -194,84 +201,46 @@ if [[ -n "$TEST_PATTERN" ]]; then
|
||||||
PYTEST_PATTERN="${PYTEST_PATTERN} and $TEST_PATTERN"
|
PYTEST_PATTERN="${PYTEST_PATTERN} and $TEST_PATTERN"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Run vision tests if specified
|
|
||||||
if [[ "$RUN_VISION_TESTS" == "true" ]]; then
|
|
||||||
echo "Running vision tests..."
|
|
||||||
set +e
|
|
||||||
pytest -s -v tests/integration/inference/test_vision_inference.py \
|
|
||||||
--stack-config="$STACK_CONFIG" \
|
|
||||||
-k "$PYTEST_PATTERN" \
|
|
||||||
--vision-model=ollama/llama3.2-vision:11b \
|
|
||||||
--embedding-model=sentence-transformers/all-MiniLM-L6-v2 \
|
|
||||||
--color=yes $EXTRA_PARAMS \
|
|
||||||
--capture=tee-sys
|
|
||||||
exit_code=$?
|
|
||||||
set -e
|
|
||||||
|
|
||||||
if [ $exit_code -eq 0 ]; then
|
|
||||||
echo "✅ Vision tests completed successfully"
|
|
||||||
elif [ $exit_code -eq 5 ]; then
|
|
||||||
echo "⚠️ No vision tests collected (pattern matched no tests)"
|
|
||||||
else
|
|
||||||
echo "❌ Vision tests failed"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
exit 0
|
|
||||||
fi
|
|
||||||
|
|
||||||
# Run regular tests
|
|
||||||
if [[ -z "$TEST_SUBDIRS" ]]; then
|
|
||||||
TEST_SUBDIRS=$(find tests/integration -maxdepth 1 -mindepth 1 -type d |
|
|
||||||
sed 's|tests/integration/||' |
|
|
||||||
grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" |
|
|
||||||
sort)
|
|
||||||
fi
|
|
||||||
echo "Test subdirs to run: $TEST_SUBDIRS"
|
echo "Test subdirs to run: $TEST_SUBDIRS"
|
||||||
|
|
||||||
# Collect all test files for the specified test types
|
if [[ -n "$TEST_SUBDIRS" ]]; then
|
||||||
TEST_FILES=""
|
# Collect all test files for the specified test types
|
||||||
for test_subdir in $(echo "$TEST_SUBDIRS" | tr ',' '\n'); do
|
TEST_FILES=""
|
||||||
# Skip certain test types for vllm provider
|
for test_subdir in $(echo "$TEST_SUBDIRS" | tr ',' '\n'); do
|
||||||
if [[ "$PROVIDER" == "vllm" ]]; then
|
if [[ -d "tests/integration/$test_subdir" ]]; then
|
||||||
if [[ "$test_subdir" == "safety" ]] || [[ "$test_subdir" == "post_training" ]] || [[ "$test_subdir" == "tool_runtime" ]]; then
|
# Find all Python test files in this directory
|
||||||
echo "Skipping $test_subdir for vllm provider"
|
test_files=$(find tests/integration/$test_subdir -name "test_*.py" -o -name "*_test.py")
|
||||||
continue
|
if [[ -n "$test_files" ]]; then
|
||||||
|
TEST_FILES="$TEST_FILES $test_files"
|
||||||
|
echo "Added test files from $test_subdir: $(echo $test_files | wc -w) files"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
echo "Warning: Directory tests/integration/$test_subdir does not exist"
|
||||||
fi
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
if [[ -z "$TEST_FILES" ]]; then
|
||||||
|
echo "No test files found for the specified test types"
|
||||||
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
if [[ "$STACK_CONFIG" != *"server:"* ]] && [[ "$test_subdir" == "batches" ]]; then
|
echo ""
|
||||||
echo "Skipping $test_subdir for library client until types are supported"
|
echo "=== Running all collected tests in a single pytest command ==="
|
||||||
continue
|
echo "Total test files: $(echo $TEST_FILES | wc -w)"
|
||||||
fi
|
|
||||||
|
|
||||||
if [[ -d "tests/integration/$test_subdir" ]]; then
|
PYTEST_TARGET="$TEST_FILES"
|
||||||
# Find all Python test files in this directory
|
EXTRA_PARAMS="$EXTRA_PARAMS --text-model=$TEXT_MODEL --embedding-model=sentence-transformers/all-MiniLM-L6-v2"
|
||||||
test_files=$(find tests/integration/$test_subdir -name "test_*.py" -o -name "*_test.py")
|
else
|
||||||
if [[ -n "$test_files" ]]; then
|
PYTEST_TARGET="tests/integration/"
|
||||||
TEST_FILES="$TEST_FILES $test_files"
|
EXTRA_PARAMS="$EXTRA_PARAMS --suite=$TEST_SUITE"
|
||||||
echo "Added test files from $test_subdir: $(echo $test_files | wc -w) files"
|
|
||||||
fi
|
|
||||||
else
|
|
||||||
echo "Warning: Directory tests/integration/$test_subdir does not exist"
|
|
||||||
fi
|
|
||||||
done
|
|
||||||
|
|
||||||
if [[ -z "$TEST_FILES" ]]; then
|
|
||||||
echo "No test files found for the specified test types"
|
|
||||||
exit 1
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
echo ""
|
|
||||||
echo "=== Running all collected tests in a single pytest command ==="
|
|
||||||
echo "Total test files: $(echo $TEST_FILES | wc -w)"
|
|
||||||
|
|
||||||
set +e
|
set +e
|
||||||
pytest -s -v $TEST_FILES \
|
pytest -s -v $PYTEST_TARGET \
|
||||||
--stack-config="$STACK_CONFIG" \
|
--stack-config="$STACK_CONFIG" \
|
||||||
-k "$PYTEST_PATTERN" \
|
-k "$PYTEST_PATTERN" \
|
||||||
--text-model="$TEXT_MODEL" \
|
$EXTRA_PARAMS \
|
||||||
--embedding-model=sentence-transformers/all-MiniLM-L6-v2 \
|
--color=yes \
|
||||||
--color=yes $EXTRA_PARAMS \
|
|
||||||
--capture=tee-sys
|
--capture=tee-sys
|
||||||
exit_code=$?
|
exit_code=$?
|
||||||
set -e
|
set -e
|
||||||
|
@ -294,7 +263,13 @@ df -h
|
||||||
# stop server
|
# stop server
|
||||||
if [[ "$STACK_CONFIG" == *"server:"* ]]; then
|
if [[ "$STACK_CONFIG" == *"server:"* ]]; then
|
||||||
echo "Stopping Llama Stack Server..."
|
echo "Stopping Llama Stack Server..."
|
||||||
kill $(lsof -i :8321 | awk 'NR>1 {print $2}')
|
pids=$(lsof -i :8321 | awk 'NR>1 {print $2}')
|
||||||
|
if [[ -n "$pids" ]]; then
|
||||||
|
echo "Killing Llama Stack Server processes: $pids"
|
||||||
|
kill -9 $pids
|
||||||
|
else
|
||||||
|
echo "No Llama Stack Server processes found ?!"
|
||||||
|
fi
|
||||||
echo "Llama Stack Server stopped"
|
echo "Llama Stack Server stopped"
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
|
@ -77,7 +77,7 @@ You must be careful when re-recording. CI workflows assume a specific setup for
|
||||||
./scripts/github/schedule-record-workflow.sh --test-subdirs "agents,inference"
|
./scripts/github/schedule-record-workflow.sh --test-subdirs "agents,inference"
|
||||||
|
|
||||||
# Record with vision tests enabled
|
# Record with vision tests enabled
|
||||||
./scripts/github/schedule-record-workflow.sh --test-subdirs "inference" --run-vision-tests
|
./scripts/github/schedule-record-workflow.sh --test-suite vision
|
||||||
|
|
||||||
# Record with specific provider
|
# Record with specific provider
|
||||||
./scripts/github/schedule-record-workflow.sh --test-subdirs "agents" --test-provider vllm
|
./scripts/github/schedule-record-workflow.sh --test-subdirs "agents" --test-provider vllm
|
||||||
|
|
|
@ -42,6 +42,27 @@ Model parameters can be influenced by the following options:
|
||||||
Each of these are comma-separated lists and can be used to generate multiple parameter combinations. Note that tests will be skipped
|
Each of these are comma-separated lists and can be used to generate multiple parameter combinations. Note that tests will be skipped
|
||||||
if no model is specified.
|
if no model is specified.
|
||||||
|
|
||||||
|
### Suites (fast selection + sane defaults)
|
||||||
|
|
||||||
|
- `--suite`: comma-separated list of named suites that both narrow which tests are collected and prefill common model options (unless you pass them explicitly).
|
||||||
|
- Available suites:
|
||||||
|
- `responses`: collects tests under `tests/integration/responses`; this is a separate suite because it needs a strong tool-calling model.
|
||||||
|
- `vision`: collects only `tests/integration/inference/test_vision_inference.py`; defaults `--vision-model=ollama/llama3.2-vision:11b`, `--embedding-model=sentence-transformers/all-MiniLM-L6-v2`.
|
||||||
|
- Explicit flags always win. For example, `--suite=responses --text-model=<X>` overrides the suite’s text model.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Fast responses run with defaults
|
||||||
|
pytest -s -v tests/integration --stack-config=server:starter --suite=responses
|
||||||
|
|
||||||
|
# Fast single-file vision run with defaults
|
||||||
|
pytest -s -v tests/integration --stack-config=server:starter --suite=vision
|
||||||
|
|
||||||
|
# Combine suites and override a default
|
||||||
|
pytest -s -v tests/integration --stack-config=server:starter --suite=responses,vision --embedding-model=text-embedding-3-small
|
||||||
|
```
|
||||||
|
|
||||||
## Examples
|
## Examples
|
||||||
|
|
||||||
### Testing against a Server
|
### Testing against a Server
|
||||||
|
|
|
@ -268,3 +268,58 @@ class TestBatchesIntegration:
|
||||||
|
|
||||||
deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
|
deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
|
||||||
assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully"
|
assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully"
|
||||||
|
|
||||||
|
def test_batch_e2e_completions(self, openai_client, batch_helper, text_model_id):
|
||||||
|
"""Run an end-to-end batch with a single successful text completion request."""
|
||||||
|
request_body = {"model": text_model_id, "prompt": "Say completions", "max_tokens": 20}
|
||||||
|
|
||||||
|
batch_requests = [
|
||||||
|
{
|
||||||
|
"custom_id": "success-1",
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/completions",
|
||||||
|
"body": request_body,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
with batch_helper.create_file(batch_requests) as uploaded_file:
|
||||||
|
batch = openai_client.batches.create(
|
||||||
|
input_file_id=uploaded_file.id,
|
||||||
|
endpoint="/v1/completions",
|
||||||
|
completion_window="24h",
|
||||||
|
metadata={"test": "e2e_completions_success"},
|
||||||
|
)
|
||||||
|
|
||||||
|
final_batch = batch_helper.wait_for(
|
||||||
|
batch.id,
|
||||||
|
max_wait_time=3 * 60,
|
||||||
|
expected_statuses={"completed"},
|
||||||
|
timeout_action="skip",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert final_batch.status == "completed"
|
||||||
|
assert final_batch.request_counts is not None
|
||||||
|
assert final_batch.request_counts.total == 1
|
||||||
|
assert final_batch.request_counts.completed == 1
|
||||||
|
assert final_batch.output_file_id is not None
|
||||||
|
|
||||||
|
output_content = openai_client.files.content(final_batch.output_file_id)
|
||||||
|
if isinstance(output_content, str):
|
||||||
|
output_text = output_content
|
||||||
|
else:
|
||||||
|
output_text = output_content.content.decode("utf-8")
|
||||||
|
|
||||||
|
output_lines = output_text.strip().split("\n")
|
||||||
|
assert len(output_lines) == 1
|
||||||
|
|
||||||
|
result = json.loads(output_lines[0])
|
||||||
|
assert result["custom_id"] == "success-1"
|
||||||
|
assert "response" in result
|
||||||
|
assert result["response"]["status_code"] == 200
|
||||||
|
|
||||||
|
deleted_output_file = openai_client.files.delete(final_batch.output_file_id)
|
||||||
|
assert deleted_output_file.deleted
|
||||||
|
|
||||||
|
if final_batch.error_file_id is not None:
|
||||||
|
deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
|
||||||
|
assert deleted_error_file.deleted
|
||||||
|
|
|
@ -6,15 +6,17 @@
|
||||||
import inspect
|
import inspect
|
||||||
import itertools
|
import itertools
|
||||||
import os
|
import os
|
||||||
import platform
|
|
||||||
import textwrap
|
import textwrap
|
||||||
import time
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
from .suites import SUITE_DEFINITIONS
|
||||||
|
|
||||||
logger = get_logger(__name__, category="tests")
|
logger = get_logger(__name__, category="tests")
|
||||||
|
|
||||||
|
|
||||||
|
@ -61,9 +63,22 @@ def pytest_configure(config):
|
||||||
key, value = env_var.split("=", 1)
|
key, value = env_var.split("=", 1)
|
||||||
os.environ[key] = value
|
os.environ[key] = value
|
||||||
|
|
||||||
if platform.system() == "Darwin": # Darwin is the system name for macOS
|
suites_raw = config.getoption("--suite")
|
||||||
os.environ["DISABLE_CODE_SANDBOX"] = "1"
|
suites: list[str] = []
|
||||||
logger.info("Setting DISABLE_CODE_SANDBOX=1 for macOS")
|
if suites_raw:
|
||||||
|
suites = [p.strip() for p in str(suites_raw).split(",") if p.strip()]
|
||||||
|
unknown = [p for p in suites if p not in SUITE_DEFINITIONS]
|
||||||
|
if unknown:
|
||||||
|
raise pytest.UsageError(
|
||||||
|
f"Unknown suite(s): {', '.join(unknown)}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}"
|
||||||
|
)
|
||||||
|
for suite in suites:
|
||||||
|
suite_def = SUITE_DEFINITIONS.get(suite, {})
|
||||||
|
defaults: dict = suite_def.get("defaults", {})
|
||||||
|
for dest, value in defaults.items():
|
||||||
|
current = getattr(config.option, dest, None)
|
||||||
|
if not current:
|
||||||
|
setattr(config.option, dest, value)
|
||||||
|
|
||||||
|
|
||||||
def pytest_addoption(parser):
|
def pytest_addoption(parser):
|
||||||
|
@ -105,16 +120,21 @@ def pytest_addoption(parser):
|
||||||
default=384,
|
default=384,
|
||||||
help="Output dimensionality of the embedding model to use for testing. Default: 384",
|
help="Output dimensionality of the embedding model to use for testing. Default: 384",
|
||||||
)
|
)
|
||||||
parser.addoption(
|
|
||||||
"--record-responses",
|
|
||||||
action="store_true",
|
|
||||||
help="Record new API responses instead of using cached ones.",
|
|
||||||
)
|
|
||||||
parser.addoption(
|
parser.addoption(
|
||||||
"--report",
|
"--report",
|
||||||
help="Path where the test report should be written, e.g. --report=/path/to/report.md",
|
help="Path where the test report should be written, e.g. --report=/path/to/report.md",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys()))
|
||||||
|
suite_help = (
|
||||||
|
"Comma-separated integration test suites to narrow collection and prefill defaults. "
|
||||||
|
"Available: "
|
||||||
|
f"{available_suites}. "
|
||||||
|
"Explicit CLI flags (e.g., --text-model) override suite defaults. "
|
||||||
|
"Examples: --suite=responses or --suite=responses,vision."
|
||||||
|
)
|
||||||
|
parser.addoption("--suite", help=suite_help)
|
||||||
|
|
||||||
|
|
||||||
MODEL_SHORT_IDS = {
|
MODEL_SHORT_IDS = {
|
||||||
"meta-llama/Llama-3.2-3B-Instruct": "3B",
|
"meta-llama/Llama-3.2-3B-Instruct": "3B",
|
||||||
|
@ -197,3 +217,40 @@ def pytest_generate_tests(metafunc):
|
||||||
|
|
||||||
|
|
||||||
pytest_plugins = ["tests.integration.fixtures.common"]
|
pytest_plugins = ["tests.integration.fixtures.common"]
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_ignore_collect(path: str, config: pytest.Config) -> bool:
|
||||||
|
"""Skip collecting paths outside the selected suite roots for speed."""
|
||||||
|
suites_raw = config.getoption("--suite")
|
||||||
|
if not suites_raw:
|
||||||
|
return False
|
||||||
|
|
||||||
|
names = [p.strip() for p in str(suites_raw).split(",") if p.strip()]
|
||||||
|
roots: list[str] = []
|
||||||
|
for name in names:
|
||||||
|
suite_def = SUITE_DEFINITIONS.get(name)
|
||||||
|
if suite_def:
|
||||||
|
roots.extend(suite_def.get("roots", []))
|
||||||
|
if not roots:
|
||||||
|
return False
|
||||||
|
|
||||||
|
p = Path(str(path)).resolve()
|
||||||
|
|
||||||
|
# Only constrain within tests/integration to avoid ignoring unrelated tests
|
||||||
|
integration_root = (Path(str(config.rootpath)) / "tests" / "integration").resolve()
|
||||||
|
if not p.is_relative_to(integration_root):
|
||||||
|
return False
|
||||||
|
|
||||||
|
for r in roots:
|
||||||
|
rp = (Path(str(config.rootpath)) / r).resolve()
|
||||||
|
if rp.is_file():
|
||||||
|
# Allow the exact file and any ancestor directories so pytest can walk into it.
|
||||||
|
if p == rp:
|
||||||
|
return False
|
||||||
|
if p.is_dir() and rp.is_relative_to(p):
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
# Allow anything inside an allowed directory
|
||||||
|
if p.is_relative_to(rp):
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
|
@ -5,6 +5,8 @@
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from ..test_cases.test_case import TestCase
|
from ..test_cases.test_case import TestCase
|
||||||
|
@ -35,6 +37,7 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
|
||||||
"remote::sambanova",
|
"remote::sambanova",
|
||||||
"remote::tgi",
|
"remote::tgi",
|
||||||
"remote::vertexai",
|
"remote::vertexai",
|
||||||
|
"remote::gemini", # https://generativelanguage.googleapis.com/v1beta/openai/completions -> 404
|
||||||
):
|
):
|
||||||
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")
|
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")
|
||||||
|
|
||||||
|
@ -56,6 +59,18 @@ def skip_if_model_doesnt_support_suffix(client_with_models, model_id):
|
||||||
pytest.skip(f"Provider {provider.provider_type} doesn't support suffix.")
|
pytest.skip(f"Provider {provider.provider_type} doesn't support suffix.")
|
||||||
|
|
||||||
|
|
||||||
|
def skip_if_doesnt_support_n(client_with_models, model_id):
|
||||||
|
provider = provider_from_model(client_with_models, model_id)
|
||||||
|
if provider.provider_type in (
|
||||||
|
"remote::sambanova",
|
||||||
|
"remote::ollama",
|
||||||
|
# Error code: 400 - [{'error': {'code': 400, 'message': 'Only one candidate can be specified in the
|
||||||
|
# current model', 'status': 'INVALID_ARGUMENT'}}]
|
||||||
|
"remote::gemini",
|
||||||
|
):
|
||||||
|
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")
|
||||||
|
|
||||||
|
|
||||||
def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
|
def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
|
||||||
provider = provider_from_model(client_with_models, model_id)
|
provider = provider_from_model(client_with_models, model_id)
|
||||||
if provider.provider_type in (
|
if provider.provider_type in (
|
||||||
|
@ -260,10 +275,7 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
|
||||||
)
|
)
|
||||||
def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
|
def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
|
||||||
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
||||||
|
skip_if_doesnt_support_n(client_with_models, text_model_id)
|
||||||
provider = provider_from_model(client_with_models, text_model_id)
|
|
||||||
if provider.provider_type == "remote::ollama":
|
|
||||||
pytest.skip(f"Model {text_model_id} hosted by {provider.provider_type} doesn't support n > 1.")
|
|
||||||
|
|
||||||
tc = TestCase(test_case)
|
tc = TestCase(test_case)
|
||||||
question = tc["question"]
|
question = tc["question"]
|
||||||
|
@ -323,8 +335,15 @@ def test_inference_store(compat_client, client_with_models, text_model_id, strea
|
||||||
response_id = response.id
|
response_id = response.id
|
||||||
content = response.choices[0].message.content
|
content = response.choices[0].message.content
|
||||||
|
|
||||||
responses = client.chat.completions.list(limit=1000)
|
tries = 0
|
||||||
assert response_id in [r.id for r in responses.data]
|
while tries < 10:
|
||||||
|
responses = client.chat.completions.list(limit=1000)
|
||||||
|
if response_id in [r.id for r in responses.data]:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
tries += 1
|
||||||
|
time.sleep(0.1)
|
||||||
|
assert tries < 10, f"Response {response_id} not found after 1 second"
|
||||||
|
|
||||||
retrieved_response = client.chat.completions.retrieve(response_id)
|
retrieved_response = client.chat.completions.retrieve(response_id)
|
||||||
assert retrieved_response.id == response_id
|
assert retrieved_response.id == response_id
|
||||||
|
@ -388,6 +407,18 @@ def test_inference_store_tool_calls(compat_client, client_with_models, text_mode
|
||||||
response_id = response.id
|
response_id = response.id
|
||||||
content = response.choices[0].message.content
|
content = response.choices[0].message.content
|
||||||
|
|
||||||
|
# wait for the response to be stored
|
||||||
|
tries = 0
|
||||||
|
while tries < 10:
|
||||||
|
responses = client.chat.completions.list(limit=1000)
|
||||||
|
if response_id in [r.id for r in responses.data]:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
tries += 1
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
assert tries < 10, f"Response {response_id} not found after 1 second"
|
||||||
|
|
||||||
responses = client.chat.completions.list(limit=1000)
|
responses = client.chat.completions.list(limit=1000)
|
||||||
assert response_id in [r.id for r in responses.data]
|
assert response_id in [r.id for r in responses.data]
|
||||||
|
|
||||||
|
|
42
tests/integration/recordings/responses/41e27b9b5d09.json
Normal file
42
tests/integration/recordings/responses/41e27b9b5d09.json
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
{
|
||||||
|
"request": {
|
||||||
|
"method": "POST",
|
||||||
|
"url": "http://0.0.0.0:11434/v1/v1/completions",
|
||||||
|
"headers": {},
|
||||||
|
"body": {
|
||||||
|
"model": "llama3.2:3b-instruct-fp16",
|
||||||
|
"prompt": "Say completions",
|
||||||
|
"max_tokens": 20
|
||||||
|
},
|
||||||
|
"endpoint": "/v1/completions",
|
||||||
|
"model": "llama3.2:3b-instruct-fp16"
|
||||||
|
},
|
||||||
|
"response": {
|
||||||
|
"body": {
|
||||||
|
"__type__": "openai.types.completion.Completion",
|
||||||
|
"__data__": {
|
||||||
|
"id": "cmpl-271",
|
||||||
|
"choices": [
|
||||||
|
{
|
||||||
|
"finish_reason": "length",
|
||||||
|
"index": 0,
|
||||||
|
"logprobs": null,
|
||||||
|
"text": "You want me to respond with a completion, but you didn't specify what I should complete. Could"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"created": 1756846620,
|
||||||
|
"model": "llama3.2:3b-instruct-fp16",
|
||||||
|
"object": "text_completion",
|
||||||
|
"system_fingerprint": "fp_ollama",
|
||||||
|
"usage": {
|
||||||
|
"completion_tokens": 20,
|
||||||
|
"prompt_tokens": 28,
|
||||||
|
"total_tokens": 48,
|
||||||
|
"completion_tokens_details": null,
|
||||||
|
"prompt_tokens_details": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"is_streaming": false
|
||||||
|
}
|
||||||
|
}
|
Before Width: | Height: | Size: 108 KiB After Width: | Height: | Size: 108 KiB |
Before Width: | Height: | Size: 148 KiB After Width: | Height: | Size: 148 KiB |
Before Width: | Height: | Size: 139 KiB After Width: | Height: | Size: 139 KiB |
53
tests/integration/suites.py
Normal file
53
tests/integration/suites.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Central definition of integration test suites. You can use these suites by passing --suite=name to pytest.
|
||||||
|
# For example:
|
||||||
|
#
|
||||||
|
# ```bash
|
||||||
|
# pytest tests/integration/ --suite=vision
|
||||||
|
# ```
|
||||||
|
#
|
||||||
|
# Each suite can:
|
||||||
|
# - restrict collection to specific roots (dirs or files)
|
||||||
|
# - provide default CLI option values (e.g. text_model, embedding_model, etc.)
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
this_dir = Path(__file__).parent
|
||||||
|
default_roots = [
|
||||||
|
str(p)
|
||||||
|
for p in this_dir.glob("*")
|
||||||
|
if p.is_dir()
|
||||||
|
and p.name not in ("__pycache__", "fixtures", "test_cases", "recordings", "responses", "post_training")
|
||||||
|
]
|
||||||
|
|
||||||
|
SUITE_DEFINITIONS: dict[str, dict] = {
|
||||||
|
"base": {
|
||||||
|
"description": "Base suite that includes most tests but runs them with a text Ollama model",
|
||||||
|
"roots": default_roots,
|
||||||
|
"defaults": {
|
||||||
|
"text_model": "ollama/llama3.2:3b-instruct-fp16",
|
||||||
|
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"responses": {
|
||||||
|
"description": "Suite that includes only the OpenAI Responses tests; needs a strong tool-calling model",
|
||||||
|
"roots": ["tests/integration/responses"],
|
||||||
|
"defaults": {
|
||||||
|
"text_model": "openai/gpt-4o",
|
||||||
|
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"vision": {
|
||||||
|
"description": "Suite that includes only the vision tests",
|
||||||
|
"roots": ["tests/integration/inference/test_vision_inference.py"],
|
||||||
|
"defaults": {
|
||||||
|
"vision_model": "ollama/llama3.2-vision:11b",
|
||||||
|
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
|
@ -17,10 +17,14 @@ def client_with_empty_registry(client_with_models):
|
||||||
client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id)
|
client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id)
|
||||||
|
|
||||||
clear_registry()
|
clear_registry()
|
||||||
|
|
||||||
|
try:
|
||||||
|
client_with_models.toolgroups.register(toolgroup_id="builtin::rag", provider_id="rag-runtime")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
yield client_with_models
|
yield client_with_models
|
||||||
|
|
||||||
# you must clean after the last test if you were running tests against
|
|
||||||
# a stateful server instance
|
|
||||||
clear_registry()
|
clear_registry()
|
||||||
|
|
||||||
|
|
||||||
|
@ -66,12 +70,13 @@ def assert_valid_text_response(response):
|
||||||
def test_vector_db_insert_inline_and_query(
|
def test_vector_db_insert_inline_and_query(
|
||||||
client_with_empty_registry, sample_documents, embedding_model_id, embedding_dimension
|
client_with_empty_registry, sample_documents, embedding_model_id, embedding_dimension
|
||||||
):
|
):
|
||||||
vector_db_id = "test_vector_db"
|
vector_db_name = "test_vector_db"
|
||||||
client_with_empty_registry.vector_dbs.register(
|
vector_db = client_with_empty_registry.vector_dbs.register(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_name,
|
||||||
embedding_model=embedding_model_id,
|
embedding_model=embedding_model_id,
|
||||||
embedding_dimension=embedding_dimension,
|
embedding_dimension=embedding_dimension,
|
||||||
)
|
)
|
||||||
|
vector_db_id = vector_db.identifier
|
||||||
|
|
||||||
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||||
documents=sample_documents,
|
documents=sample_documents,
|
||||||
|
@ -134,7 +139,11 @@ def test_vector_db_insert_from_url_and_query(
|
||||||
|
|
||||||
# list to check memory bank is successfully registered
|
# list to check memory bank is successfully registered
|
||||||
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||||
assert vector_db_id in available_vector_dbs
|
# VectorDB is being migrated to VectorStore, so the ID will be different
|
||||||
|
# Just check that at least one vector DB was registered
|
||||||
|
assert len(available_vector_dbs) > 0
|
||||||
|
# Use the actual registered vector_db_id for subsequent operations
|
||||||
|
actual_vector_db_id = available_vector_dbs[0]
|
||||||
|
|
||||||
urls = [
|
urls = [
|
||||||
"memory_optimizations.rst",
|
"memory_optimizations.rst",
|
||||||
|
@ -153,13 +162,13 @@ def test_vector_db_insert_from_url_and_query(
|
||||||
|
|
||||||
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||||
documents=documents,
|
documents=documents,
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=actual_vector_db_id,
|
||||||
chunk_size_in_tokens=512,
|
chunk_size_in_tokens=512,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Query for the name of method
|
# Query for the name of method
|
||||||
response1 = client_with_empty_registry.vector_io.query(
|
response1 = client_with_empty_registry.vector_io.query(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=actual_vector_db_id,
|
||||||
query="What's the name of the fine-tunning method used?",
|
query="What's the name of the fine-tunning method used?",
|
||||||
)
|
)
|
||||||
assert_valid_chunk_response(response1)
|
assert_valid_chunk_response(response1)
|
||||||
|
@ -167,7 +176,7 @@ def test_vector_db_insert_from_url_and_query(
|
||||||
|
|
||||||
# Query for the name of model
|
# Query for the name of model
|
||||||
response2 = client_with_empty_registry.vector_io.query(
|
response2 = client_with_empty_registry.vector_io.query(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=actual_vector_db_id,
|
||||||
query="Which Llama model is mentioned?",
|
query="Which Llama model is mentioned?",
|
||||||
)
|
)
|
||||||
assert_valid_chunk_response(response2)
|
assert_valid_chunk_response(response2)
|
||||||
|
@ -187,7 +196,11 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
|
||||||
)
|
)
|
||||||
|
|
||||||
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||||
assert vector_db_id in available_vector_dbs
|
# VectorDB is being migrated to VectorStore, so the ID will be different
|
||||||
|
# Just check that at least one vector DB was registered
|
||||||
|
assert len(available_vector_dbs) > 0
|
||||||
|
# Use the actual registered vector_db_id for subsequent operations
|
||||||
|
actual_vector_db_id = available_vector_dbs[0]
|
||||||
|
|
||||||
urls = [
|
urls = [
|
||||||
"memory_optimizations.rst",
|
"memory_optimizations.rst",
|
||||||
|
@ -206,19 +219,19 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
|
||||||
|
|
||||||
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
client_with_empty_registry.tool_runtime.rag_tool.insert(
|
||||||
documents=documents,
|
documents=documents,
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=actual_vector_db_id,
|
||||||
chunk_size_in_tokens=512,
|
chunk_size_in_tokens=512,
|
||||||
)
|
)
|
||||||
|
|
||||||
response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
|
response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||||
vector_db_ids=[vector_db_id],
|
vector_db_ids=[actual_vector_db_id],
|
||||||
content="What is the name of the method used for fine-tuning?",
|
content="What is the name of the method used for fine-tuning?",
|
||||||
)
|
)
|
||||||
assert_valid_text_response(response_with_metadata)
|
assert_valid_text_response(response_with_metadata)
|
||||||
assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content)
|
assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content)
|
||||||
|
|
||||||
response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
|
response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||||
vector_db_ids=[vector_db_id],
|
vector_db_ids=[actual_vector_db_id],
|
||||||
content="What is the name of the method used for fine-tuning?",
|
content="What is the name of the method used for fine-tuning?",
|
||||||
query_config={
|
query_config={
|
||||||
"include_metadata_in_content": True,
|
"include_metadata_in_content": True,
|
||||||
|
@ -230,7 +243,7 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
|
||||||
|
|
||||||
with pytest.raises((ValueError, BadRequestError)):
|
with pytest.raises((ValueError, BadRequestError)):
|
||||||
client_with_empty_registry.tool_runtime.rag_tool.query(
|
client_with_empty_registry.tool_runtime.rag_tool.query(
|
||||||
vector_db_ids=[vector_db_id],
|
vector_db_ids=[actual_vector_db_id],
|
||||||
content="What is the name of the method used for fine-tuning?",
|
content="What is the name of the method used for fine-tuning?",
|
||||||
query_config={
|
query_config={
|
||||||
"chunk_template": "This should raise a ValueError because it is missing the proper template variables",
|
"chunk_template": "This should raise a ValueError because it is missing the proper template variables",
|
||||||
|
|
|
@ -47,34 +47,45 @@ def client_with_empty_registry(client_with_models):
|
||||||
|
|
||||||
|
|
||||||
def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||||
# Register a memory bank first
|
vector_db_name = "test_vector_db"
|
||||||
vector_db_id = "test_vector_db"
|
register_response = client_with_empty_registry.vector_dbs.register(
|
||||||
client_with_empty_registry.vector_dbs.register(
|
vector_db_id=vector_db_name,
|
||||||
vector_db_id=vector_db_id,
|
|
||||||
embedding_model=embedding_model_id,
|
embedding_model=embedding_model_id,
|
||||||
embedding_dimension=embedding_dimension,
|
embedding_dimension=embedding_dimension,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
actual_vector_db_id = register_response.identifier
|
||||||
|
|
||||||
# Retrieve the memory bank and validate its properties
|
# Retrieve the memory bank and validate its properties
|
||||||
response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=vector_db_id)
|
response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=actual_vector_db_id)
|
||||||
assert response is not None
|
assert response is not None
|
||||||
assert response.identifier == vector_db_id
|
assert response.identifier == actual_vector_db_id
|
||||||
assert response.embedding_model == embedding_model_id
|
assert response.embedding_model == embedding_model_id
|
||||||
assert response.provider_resource_id == vector_db_id
|
assert response.identifier.startswith("vs_")
|
||||||
|
|
||||||
|
|
||||||
def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension):
|
||||||
vector_db_id = "test_vector_db"
|
vector_db_name = "test_vector_db"
|
||||||
client_with_empty_registry.vector_dbs.register(
|
response = client_with_empty_registry.vector_dbs.register(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_name,
|
||||||
embedding_model=embedding_model_id,
|
embedding_model=embedding_model_id,
|
||||||
embedding_dimension=embedding_dimension,
|
embedding_dimension=embedding_dimension,
|
||||||
)
|
)
|
||||||
|
|
||||||
vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
actual_vector_db_id = response.identifier
|
||||||
assert vector_dbs_after_register == [vector_db_id]
|
assert actual_vector_db_id.startswith("vs_")
|
||||||
|
assert actual_vector_db_id != vector_db_name
|
||||||
|
|
||||||
client_with_empty_registry.vector_dbs.unregister(vector_db_id=vector_db_id)
|
vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||||
|
assert vector_dbs_after_register == [actual_vector_db_id]
|
||||||
|
|
||||||
|
vector_stores = client_with_empty_registry.vector_stores.list()
|
||||||
|
assert len(vector_stores.data) == 1
|
||||||
|
vector_store = vector_stores.data[0]
|
||||||
|
assert vector_store.id == actual_vector_db_id
|
||||||
|
assert vector_store.name == vector_db_name
|
||||||
|
|
||||||
|
client_with_empty_registry.vector_dbs.unregister(vector_db_id=actual_vector_db_id)
|
||||||
|
|
||||||
vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
|
||||||
assert len(vector_dbs) == 0
|
assert len(vector_dbs) == 0
|
||||||
|
@ -91,20 +102,22 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case):
|
def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case):
|
||||||
vector_db_id = "test_vector_db"
|
vector_db_name = "test_vector_db"
|
||||||
client_with_empty_registry.vector_dbs.register(
|
register_response = client_with_empty_registry.vector_dbs.register(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_name,
|
||||||
embedding_model=embedding_model_id,
|
embedding_model=embedding_model_id,
|
||||||
embedding_dimension=embedding_dimension,
|
embedding_dimension=embedding_dimension,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
actual_vector_db_id = register_response.identifier
|
||||||
|
|
||||||
client_with_empty_registry.vector_io.insert(
|
client_with_empty_registry.vector_io.insert(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=actual_vector_db_id,
|
||||||
chunks=sample_chunks,
|
chunks=sample_chunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client_with_empty_registry.vector_io.query(
|
response = client_with_empty_registry.vector_io.query(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=actual_vector_db_id,
|
||||||
query="What is the capital of France?",
|
query="What is the capital of France?",
|
||||||
)
|
)
|
||||||
assert response is not None
|
assert response is not None
|
||||||
|
@ -113,7 +126,7 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding
|
||||||
|
|
||||||
query, expected_doc_id = test_case
|
query, expected_doc_id = test_case
|
||||||
response = client_with_empty_registry.vector_io.query(
|
response = client_with_empty_registry.vector_io.query(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=actual_vector_db_id,
|
||||||
query=query,
|
query=query,
|
||||||
)
|
)
|
||||||
assert response is not None
|
assert response is not None
|
||||||
|
@ -128,13 +141,15 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
|
||||||
"remote::qdrant": {"score_threshold": -1.0},
|
"remote::qdrant": {"score_threshold": -1.0},
|
||||||
"inline::qdrant": {"score_threshold": -1.0},
|
"inline::qdrant": {"score_threshold": -1.0},
|
||||||
}
|
}
|
||||||
vector_db_id = "test_precomputed_embeddings_db"
|
vector_db_name = "test_precomputed_embeddings_db"
|
||||||
client_with_empty_registry.vector_dbs.register(
|
register_response = client_with_empty_registry.vector_dbs.register(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_name,
|
||||||
embedding_model=embedding_model_id,
|
embedding_model=embedding_model_id,
|
||||||
embedding_dimension=embedding_dimension,
|
embedding_dimension=embedding_dimension,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
actual_vector_db_id = register_response.identifier
|
||||||
|
|
||||||
chunks_with_embeddings = [
|
chunks_with_embeddings = [
|
||||||
Chunk(
|
Chunk(
|
||||||
content="This is a test chunk with precomputed embedding.",
|
content="This is a test chunk with precomputed embedding.",
|
||||||
|
@ -144,13 +159,13 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
|
||||||
]
|
]
|
||||||
|
|
||||||
client_with_empty_registry.vector_io.insert(
|
client_with_empty_registry.vector_io.insert(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=actual_vector_db_id,
|
||||||
chunks=chunks_with_embeddings,
|
chunks=chunks_with_embeddings,
|
||||||
)
|
)
|
||||||
|
|
||||||
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
|
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
|
||||||
response = client_with_empty_registry.vector_io.query(
|
response = client_with_empty_registry.vector_io.query(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=actual_vector_db_id,
|
||||||
query="precomputed embedding test",
|
query="precomputed embedding test",
|
||||||
params=vector_io_provider_params_dict.get(provider, None),
|
params=vector_io_provider_params_dict.get(provider, None),
|
||||||
)
|
)
|
||||||
|
@ -173,13 +188,15 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
|
||||||
"remote::qdrant": {"score_threshold": 0.0},
|
"remote::qdrant": {"score_threshold": 0.0},
|
||||||
"inline::qdrant": {"score_threshold": 0.0},
|
"inline::qdrant": {"score_threshold": 0.0},
|
||||||
}
|
}
|
||||||
vector_db_id = "test_precomputed_embeddings_db"
|
vector_db_name = "test_precomputed_embeddings_db"
|
||||||
client_with_empty_registry.vector_dbs.register(
|
register_response = client_with_empty_registry.vector_dbs.register(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=vector_db_name,
|
||||||
embedding_model=embedding_model_id,
|
embedding_model=embedding_model_id,
|
||||||
embedding_dimension=embedding_dimension,
|
embedding_dimension=embedding_dimension,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
actual_vector_db_id = register_response.identifier
|
||||||
|
|
||||||
chunks_with_embeddings = [
|
chunks_with_embeddings = [
|
||||||
Chunk(
|
Chunk(
|
||||||
content="duplicate",
|
content="duplicate",
|
||||||
|
@ -189,13 +206,13 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
|
||||||
]
|
]
|
||||||
|
|
||||||
client_with_empty_registry.vector_io.insert(
|
client_with_empty_registry.vector_io.insert(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=actual_vector_db_id,
|
||||||
chunks=chunks_with_embeddings,
|
chunks=chunks_with_embeddings,
|
||||||
)
|
)
|
||||||
|
|
||||||
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
|
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
|
||||||
response = client_with_empty_registry.vector_io.query(
|
response = client_with_empty_registry.vector_io.query(
|
||||||
vector_db_id=vector_db_id,
|
vector_db_id=actual_vector_db_id,
|
||||||
query="duplicate",
|
query="duplicate",
|
||||||
params=vector_io_provider_params_dict.get(provider, None),
|
params=vector_io_provider_params_dict.get(provider, None),
|
||||||
)
|
)
|
||||||
|
|
|
@ -146,6 +146,20 @@ class VectorDBImpl(Impl):
|
||||||
async def unregister_vector_db(self, vector_db_id: str):
|
async def unregister_vector_db(self, vector_db_id: str):
|
||||||
return vector_db_id
|
return vector_db_id
|
||||||
|
|
||||||
|
async def openai_create_vector_store(self, **kwargs):
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from llama_stack.apis.vector_io.vector_io import VectorStoreFileCounts, VectorStoreObject
|
||||||
|
|
||||||
|
vector_store_id = kwargs.get("provider_vector_db_id") or f"vs_{uuid.uuid4()}"
|
||||||
|
return VectorStoreObject(
|
||||||
|
id=vector_store_id,
|
||||||
|
name=kwargs.get("name", vector_store_id),
|
||||||
|
created_at=int(time.time()),
|
||||||
|
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_models_routing_table(cached_disk_dist_registry):
|
async def test_models_routing_table(cached_disk_dist_registry):
|
||||||
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||||
|
@ -247,17 +261,21 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Register multiple vector databases and verify listing
|
# Register multiple vector databases and verify listing
|
||||||
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
|
vdb1 = await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
|
||||||
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
|
vdb2 = await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
|
||||||
vector_dbs = await table.list_vector_dbs()
|
vector_dbs = await table.list_vector_dbs()
|
||||||
|
|
||||||
assert len(vector_dbs.data) == 2
|
assert len(vector_dbs.data) == 2
|
||||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||||
assert "test-vectordb" in vector_db_ids
|
assert vdb1.identifier in vector_db_ids
|
||||||
assert "test-vectordb-2" in vector_db_ids
|
assert vdb2.identifier in vector_db_ids
|
||||||
|
|
||||||
await table.unregister_vector_db(vector_db_id="test-vectordb")
|
# Verify they have UUID-based identifiers
|
||||||
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
|
assert vdb1.identifier.startswith("vs_")
|
||||||
|
assert vdb2.identifier.startswith("vs_")
|
||||||
|
|
||||||
|
await table.unregister_vector_db(vector_db_id=vdb1.identifier)
|
||||||
|
await table.unregister_vector_db(vector_db_id=vdb2.identifier)
|
||||||
|
|
||||||
vector_dbs = await table.list_vector_dbs()
|
vector_dbs = await table.list_vector_dbs()
|
||||||
assert len(vector_dbs.data) == 0
|
assert len(vector_dbs.data) == 0
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
# Unit tests for the routing tables vector_dbs
|
# Unit tests for the routing tables vector_dbs
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
import uuid
|
||||||
from unittest.mock import AsyncMock
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
@ -34,6 +35,7 @@ from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceI
|
||||||
class VectorDBImpl(Impl):
|
class VectorDBImpl(Impl):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__(Api.vector_io)
|
super().__init__(Api.vector_io)
|
||||||
|
self.vector_stores = {}
|
||||||
|
|
||||||
async def register_vector_db(self, vector_db: VectorDB):
|
async def register_vector_db(self, vector_db: VectorDB):
|
||||||
return vector_db
|
return vector_db
|
||||||
|
@ -114,8 +116,35 @@ class VectorDBImpl(Impl):
|
||||||
async def openai_delete_vector_store_file(self, vector_store_id, file_id):
|
async def openai_delete_vector_store_file(self, vector_store_id, file_id):
|
||||||
return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
|
return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
|
||||||
|
|
||||||
|
async def openai_create_vector_store(
|
||||||
|
self,
|
||||||
|
name=None,
|
||||||
|
embedding_model=None,
|
||||||
|
embedding_dimension=None,
|
||||||
|
provider_id=None,
|
||||||
|
provider_vector_db_id=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
vector_store_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
|
||||||
|
vector_store = VectorStoreObject(
|
||||||
|
id=vector_store_id,
|
||||||
|
name=name or vector_store_id,
|
||||||
|
created_at=int(time.time()),
|
||||||
|
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
|
||||||
|
)
|
||||||
|
self.vector_stores[vector_store_id] = vector_store
|
||||||
|
return vector_store
|
||||||
|
|
||||||
|
async def openai_list_vector_stores(self, **kwargs):
|
||||||
|
from llama_stack.apis.vector_io.vector_io import VectorStoreListResponse
|
||||||
|
|
||||||
|
return VectorStoreListResponse(
|
||||||
|
data=list(self.vector_stores.values()), has_more=False, first_id=None, last_id=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||||
|
n = 10
|
||||||
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
|
||||||
await table.initialize()
|
await table.initialize()
|
||||||
|
|
||||||
|
@ -129,22 +158,98 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Register multiple vector databases and verify listing
|
# Register multiple vector databases and verify listing
|
||||||
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
|
vdb_dict = {}
|
||||||
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
|
for i in range(n):
|
||||||
|
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
|
||||||
|
|
||||||
vector_dbs = await table.list_vector_dbs()
|
vector_dbs = await table.list_vector_dbs()
|
||||||
|
|
||||||
assert len(vector_dbs.data) == 2
|
assert len(vector_dbs.data) == len(vdb_dict)
|
||||||
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||||
assert "test-vectordb" in vector_db_ids
|
for k in vdb_dict:
|
||||||
assert "test-vectordb-2" in vector_db_ids
|
assert vdb_dict[k].identifier in vector_db_ids
|
||||||
|
for k in vdb_dict:
|
||||||
await table.unregister_vector_db(vector_db_id="test-vectordb")
|
await table.unregister_vector_db(vector_db_id=vdb_dict[k].identifier)
|
||||||
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
|
|
||||||
|
|
||||||
vector_dbs = await table.list_vector_dbs()
|
vector_dbs = await table.list_vector_dbs()
|
||||||
assert len(vector_dbs.data) == 0
|
assert len(vector_dbs.data) == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_vector_db_and_vector_store_id_mapping(cached_disk_dist_registry):
|
||||||
|
n = 10
|
||||||
|
impl = VectorDBImpl()
|
||||||
|
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
|
||||||
|
await table.initialize()
|
||||||
|
|
||||||
|
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||||
|
await m_table.initialize()
|
||||||
|
await m_table.register_model(
|
||||||
|
model_id="test-model",
|
||||||
|
provider_id="test_provider",
|
||||||
|
metadata={"embedding_dimension": 128},
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
vdb_dict = {}
|
||||||
|
for i in range(n):
|
||||||
|
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
|
||||||
|
|
||||||
|
vector_dbs = await table.list_vector_dbs()
|
||||||
|
vector_db_ids = {v.identifier for v in vector_dbs.data}
|
||||||
|
|
||||||
|
vector_stores = await impl.openai_list_vector_stores()
|
||||||
|
vector_store_ids = {v.id for v in vector_stores.data}
|
||||||
|
|
||||||
|
assert vector_db_ids == vector_store_ids, (
|
||||||
|
f"Vector DB IDs {vector_db_ids} don't match vector store IDs {vector_store_ids}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for vector_store in vector_stores.data:
|
||||||
|
vector_db = await table.get_vector_db(vector_store.id)
|
||||||
|
assert vector_store.name == vector_db.vector_db_name, (
|
||||||
|
f"Vector store name {vector_store.name} doesn't match vector store ID {vector_store.id}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for vector_db_id in vector_db_ids:
|
||||||
|
await table.unregister_vector_db(vector_db_id)
|
||||||
|
|
||||||
|
assert len((await table.list_vector_dbs()).data) == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_vector_db_id_becomes_vector_store_name(cached_disk_dist_registry):
|
||||||
|
impl = VectorDBImpl()
|
||||||
|
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
|
||||||
|
await table.initialize()
|
||||||
|
|
||||||
|
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
|
||||||
|
await m_table.initialize()
|
||||||
|
await m_table.register_model(
|
||||||
|
model_id="test-model",
|
||||||
|
provider_id="test_provider",
|
||||||
|
metadata={"embedding_dimension": 128},
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
)
|
||||||
|
|
||||||
|
user_provided_id = "my-custom-vector-db"
|
||||||
|
await table.register_vector_db(vector_db_id=user_provided_id, embedding_model="test-model")
|
||||||
|
|
||||||
|
vector_stores = await impl.openai_list_vector_stores()
|
||||||
|
assert len(vector_stores.data) == 1
|
||||||
|
|
||||||
|
vector_store = vector_stores.data[0]
|
||||||
|
|
||||||
|
assert vector_store.name == user_provided_id
|
||||||
|
|
||||||
|
assert vector_store.id.startswith("vs_")
|
||||||
|
assert vector_store.id != user_provided_id
|
||||||
|
|
||||||
|
vector_dbs = await table.list_vector_dbs()
|
||||||
|
assert len(vector_dbs.data) == 1
|
||||||
|
assert vector_dbs.data[0].identifier == vector_store.id
|
||||||
|
|
||||||
|
await table.unregister_vector_db(vector_store.id)
|
||||||
|
|
||||||
|
|
||||||
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
|
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
|
||||||
impl = VectorDBImpl()
|
impl = VectorDBImpl()
|
||||||
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
|
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
|
||||||
|
@ -164,7 +269,8 @@ async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registr
|
||||||
|
|
||||||
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
|
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
|
||||||
with request_provider_data_context({}, authorized_user):
|
with request_provider_data_context({}, authorized_user):
|
||||||
_ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
|
registered_vdb = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
|
||||||
|
authorized_table = registered_vdb.identifier # Use the actual generated ID
|
||||||
|
|
||||||
# Authorized reader
|
# Authorized reader
|
||||||
with request_provider_data_context({}, authorized_user):
|
with request_provider_data_context({}, authorized_user):
|
||||||
|
@ -227,7 +333,8 @@ async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_regis
|
||||||
)
|
)
|
||||||
|
|
||||||
with request_provider_data_context({}, admin_user):
|
with request_provider_data_context({}, admin_user):
|
||||||
await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
|
registered_vdb = await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
|
||||||
|
vector_db_id = registered_vdb.identifier # Use the actual generated ID
|
||||||
|
|
||||||
read_methods = [
|
read_methods = [
|
||||||
(table.openai_retrieve_vector_store, (vector_db_id,), {}),
|
(table.openai_retrieve_vector_store, (vector_db_id,), {}),
|
||||||
|
|
|
@ -46,7 +46,8 @@ The tests are categorized and outlined below, keep this updated:
|
||||||
* test_validate_input_url_mismatch (negative)
|
* test_validate_input_url_mismatch (negative)
|
||||||
* test_validate_input_multiple_errors_per_request (negative)
|
* test_validate_input_multiple_errors_per_request (negative)
|
||||||
* test_validate_input_invalid_request_format (negative)
|
* test_validate_input_invalid_request_format (negative)
|
||||||
* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
|
* test_validate_input_missing_parameters_chat_completions (parametrized negative - custom_id, method, url, body, model, messages missing validation for chat/completions)
|
||||||
|
* test_validate_input_missing_parameters_completions (parametrized negative - custom_id, method, url, body, model, prompt missing validation for completions)
|
||||||
* test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
|
* test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
|
||||||
|
|
||||||
The tests use temporary SQLite databases for isolation and mock external
|
The tests use temporary SQLite databases for isolation and mock external
|
||||||
|
@ -213,7 +214,6 @@ class TestReferenceBatchesImpl:
|
||||||
"endpoint",
|
"endpoint",
|
||||||
[
|
[
|
||||||
"/v1/embeddings",
|
"/v1/embeddings",
|
||||||
"/v1/completions",
|
|
||||||
"/v1/invalid/endpoint",
|
"/v1/invalid/endpoint",
|
||||||
"",
|
"",
|
||||||
],
|
],
|
||||||
|
@ -499,8 +499,10 @@ class TestReferenceBatchesImpl:
|
||||||
("messages", "body.messages", "invalid_request", "Messages parameter is required"),
|
("messages", "body.messages", "invalid_request", "Messages parameter is required"),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
|
async def test_validate_input_missing_parameters_chat_completions(
|
||||||
"""Test _validate_input when file contains request with missing required parameters."""
|
self, provider, param_name, param_path, error_code, error_message
|
||||||
|
):
|
||||||
|
"""Test _validate_input when file contains request with missing required parameters for chat completions."""
|
||||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
|
|
||||||
|
@ -541,6 +543,61 @@ class TestReferenceBatchesImpl:
|
||||||
assert errors[0].message == error_message
|
assert errors[0].message == error_message
|
||||||
assert errors[0].param == param_path
|
assert errors[0].param == param_path
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"param_name,param_path,error_code,error_message",
|
||||||
|
[
|
||||||
|
("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
|
||||||
|
("method", "method", "missing_required_parameter", "Missing required parameter: method"),
|
||||||
|
("url", "url", "missing_required_parameter", "Missing required parameter: url"),
|
||||||
|
("body", "body", "missing_required_parameter", "Missing required parameter: body"),
|
||||||
|
("model", "body.model", "invalid_request", "Model parameter is required"),
|
||||||
|
("prompt", "body.prompt", "invalid_request", "Prompt parameter is required"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
async def test_validate_input_missing_parameters_completions(
|
||||||
|
self, provider, param_name, param_path, error_code, error_message
|
||||||
|
):
|
||||||
|
"""Test _validate_input when file contains request with missing required parameters for text completions."""
|
||||||
|
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
|
||||||
|
base_request = {
|
||||||
|
"custom_id": "req-1",
|
||||||
|
"method": "POST",
|
||||||
|
"url": "/v1/completions",
|
||||||
|
"body": {"model": "test-model", "prompt": "Hello"},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Remove the specific parameter being tested
|
||||||
|
if "." in param_path:
|
||||||
|
top_level, nested_param = param_path.split(".", 1)
|
||||||
|
del base_request[top_level][nested_param]
|
||||||
|
else:
|
||||||
|
del base_request[param_name]
|
||||||
|
|
||||||
|
mock_response.body = json.dumps(base_request).encode()
|
||||||
|
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
batch = BatchObject(
|
||||||
|
id="batch_test",
|
||||||
|
object="batch",
|
||||||
|
endpoint="/v1/completions",
|
||||||
|
input_file_id=f"missing_{param_name}_file",
|
||||||
|
completion_window="24h",
|
||||||
|
status="validating",
|
||||||
|
created_at=1234567890,
|
||||||
|
)
|
||||||
|
|
||||||
|
errors, requests = await provider._validate_input(batch)
|
||||||
|
|
||||||
|
assert len(errors) == 1
|
||||||
|
assert len(requests) == 0
|
||||||
|
|
||||||
|
assert errors[0].code == error_code
|
||||||
|
assert errors[0].line == 1
|
||||||
|
assert errors[0].message == error_message
|
||||||
|
assert errors[0].param == param_path
|
||||||
|
|
||||||
async def test_validate_input_url_mismatch(self, provider):
|
async def test_validate_input_url_mismatch(self, provider):
|
||||||
"""Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
|
"""Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
|
||||||
provider.files_api.openai_retrieve_file = AsyncMock()
|
provider.files_api.openai_retrieve_file = AsyncMock()
|
||||||
|
|
63
tests/unit/providers/inference/bedrock/test_config.py
Normal file
63
tests/unit/providers/inference/bedrock/test_config.py
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
|
||||||
|
|
||||||
|
|
||||||
|
class TestBedrockBaseConfig:
|
||||||
|
def test_defaults_work_without_env_vars(self):
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
config = BedrockBaseConfig()
|
||||||
|
|
||||||
|
# Basic creds should be None
|
||||||
|
assert config.aws_access_key_id is None
|
||||||
|
assert config.aws_secret_access_key is None
|
||||||
|
assert config.region_name is None
|
||||||
|
|
||||||
|
# Timeouts get defaults
|
||||||
|
assert config.connect_timeout == 60.0
|
||||||
|
assert config.read_timeout == 60.0
|
||||||
|
assert config.session_ttl == 3600
|
||||||
|
|
||||||
|
def test_env_vars_get_picked_up(self):
|
||||||
|
env_vars = {
|
||||||
|
"AWS_ACCESS_KEY_ID": "AKIATEST123",
|
||||||
|
"AWS_SECRET_ACCESS_KEY": "secret123",
|
||||||
|
"AWS_DEFAULT_REGION": "us-west-2",
|
||||||
|
"AWS_MAX_ATTEMPTS": "5",
|
||||||
|
"AWS_RETRY_MODE": "adaptive",
|
||||||
|
"AWS_CONNECT_TIMEOUT": "30",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.dict(os.environ, env_vars, clear=True):
|
||||||
|
config = BedrockBaseConfig()
|
||||||
|
|
||||||
|
assert config.aws_access_key_id == "AKIATEST123"
|
||||||
|
assert config.aws_secret_access_key == "secret123"
|
||||||
|
assert config.region_name == "us-west-2"
|
||||||
|
assert config.total_max_attempts == 5
|
||||||
|
assert config.retry_mode == "adaptive"
|
||||||
|
assert config.connect_timeout == 30.0
|
||||||
|
|
||||||
|
def test_partial_env_setup(self):
|
||||||
|
# Just setting one timeout var
|
||||||
|
with patch.dict(os.environ, {"AWS_CONNECT_TIMEOUT": "120"}, clear=True):
|
||||||
|
config = BedrockBaseConfig()
|
||||||
|
|
||||||
|
assert config.connect_timeout == 120.0
|
||||||
|
assert config.read_timeout == 60.0 # still default
|
||||||
|
assert config.aws_access_key_id is None
|
||||||
|
|
||||||
|
def test_bad_max_attempts_breaks(self):
|
||||||
|
with patch.dict(os.environ, {"AWS_MAX_ATTEMPTS": "not_a_number"}, clear=True):
|
||||||
|
try:
|
||||||
|
BedrockBaseConfig()
|
||||||
|
raise AssertionError("Should have failed on bad int conversion")
|
||||||
|
except ValueError:
|
||||||
|
pass # expected
|
|
@ -19,12 +19,16 @@ from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRunti
|
||||||
|
|
||||||
class TestRagQuery:
|
class TestRagQuery:
|
||||||
async def test_query_raises_on_empty_vector_db_ids(self):
|
async def test_query_raises_on_empty_vector_db_ids(self):
|
||||||
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
|
rag_tool = MemoryToolRuntimeImpl(
|
||||||
|
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
|
||||||
|
)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await rag_tool.query(content=MagicMock(), vector_db_ids=[])
|
await rag_tool.query(content=MagicMock(), vector_db_ids=[])
|
||||||
|
|
||||||
async def test_query_chunk_metadata_handling(self):
|
async def test_query_chunk_metadata_handling(self):
|
||||||
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock())
|
rag_tool = MemoryToolRuntimeImpl(
|
||||||
|
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
|
||||||
|
)
|
||||||
content = "test query content"
|
content = "test query content"
|
||||||
vector_db_ids = ["db1"]
|
vector_db_ids = ["db1"]
|
||||||
|
|
||||||
|
|
|
@ -113,6 +113,15 @@ class TestTranslateException:
|
||||||
assert result.status_code == 504
|
assert result.status_code == 504
|
||||||
assert result.detail == "Operation timed out: "
|
assert result.detail == "Operation timed out: "
|
||||||
|
|
||||||
|
def test_translate_connection_error(self):
|
||||||
|
"""Test that ConnectionError is translated to 502 HTTP status."""
|
||||||
|
exc = ConnectionError("Failed to connect to MCP server at http://localhost:9999/sse: Connection refused")
|
||||||
|
result = translate_exception(exc)
|
||||||
|
|
||||||
|
assert isinstance(result, HTTPException)
|
||||||
|
assert result.status_code == 502
|
||||||
|
assert result.detail == "Failed to connect to MCP server at http://localhost:9999/sse: Connection refused"
|
||||||
|
|
||||||
def test_translate_not_implemented_error(self):
|
def test_translate_not_implemented_error(self):
|
||||||
"""Test that NotImplementedError is translated to 501 HTTP status."""
|
"""Test that NotImplementedError is translated to 501 HTTP status."""
|
||||||
exc = NotImplementedError("Not implemented")
|
exc = NotImplementedError("Not implemented")
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue