Merge branch 'main' into prompt-api

Francisco Arceo 2025-09-05 15:11:00 -06:00 committed by GitHub
commit 5d610de5db
34 changed files with 679 additions and 240 deletions

View file

@ -2,13 +2,6 @@ name: 'Run and Record Tests'
description: 'Run integration tests and handle recording/artifact upload' description: 'Run integration tests and handle recording/artifact upload'
inputs: inputs:
test-subdirs:
description: 'Comma-separated list of test subdirectories to run'
required: true
test-pattern:
description: 'Regex pattern to pass to pytest -k'
required: false
default: ''
stack-config: stack-config:
description: 'Stack configuration to use' description: 'Stack configuration to use'
required: true required: true
@ -18,10 +11,18 @@ inputs:
inference-mode: inference-mode:
description: 'Inference mode (record or replay)' description: 'Inference mode (record or replay)'
required: true required: true
run-vision-tests: test-suite:
description: 'Whether to run vision tests' description: 'Test suite to use: base, responses, vision, etc.'
required: false required: false
default: 'false' default: ''
test-subdirs:
description: 'Comma-separated list of test subdirectories to run; overrides test-suite'
required: false
default: ''
test-pattern:
description: 'Regex pattern to pass to pytest -k'
required: false
default: ''
runs: runs:
using: 'composite' using: 'composite'
@ -42,7 +43,7 @@ runs:
--test-subdirs '${{ inputs.test-subdirs }}' \ --test-subdirs '${{ inputs.test-subdirs }}' \
--test-pattern '${{ inputs.test-pattern }}' \ --test-pattern '${{ inputs.test-pattern }}' \
--inference-mode '${{ inputs.inference-mode }}' \ --inference-mode '${{ inputs.inference-mode }}' \
${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }} \ --test-suite '${{ inputs.test-suite }}' \
| tee pytest-${{ inputs.inference-mode }}.log | tee pytest-${{ inputs.inference-mode }}.log
@ -57,12 +58,7 @@ runs:
echo "New recordings detected, committing and pushing" echo "New recordings detected, committing and pushing"
git add tests/integration/recordings/ git add tests/integration/recordings/
if [ "${{ inputs.run-vision-tests }}" == "true" ]; then git commit -m "Recordings update from CI (test-suite: ${{ inputs.test-suite }})"
git commit -m "Recordings update from CI (vision)"
else
git commit -m "Recordings update from CI"
fi
git fetch origin ${{ github.ref_name }} git fetch origin ${{ github.ref_name }}
git rebase origin/${{ github.ref_name }} git rebase origin/${{ github.ref_name }}
echo "Rebased successfully" echo "Rebased successfully"

View file

@ -1,17 +1,17 @@
name: Setup Ollama name: Setup Ollama
description: Start Ollama description: Start Ollama
inputs: inputs:
run-vision-tests: test-suite:
description: 'Run vision tests: "true" or "false"' description: 'Test suite to use: base, responses, vision, etc.'
required: false required: false
default: 'false' default: ''
runs: runs:
using: "composite" using: "composite"
steps: steps:
- name: Start Ollama - name: Start Ollama
shell: bash shell: bash
run: | run: |
if [ "${{ inputs.run-vision-tests }}" == "true" ]; then if [ "${{ inputs.test-suite }}" == "vision" ]; then
image="ollama-with-vision-model" image="ollama-with-vision-model"
else else
image="ollama-with-models" image="ollama-with-models"

View file

@ -12,10 +12,10 @@ inputs:
description: 'Provider to setup (ollama or vllm)' description: 'Provider to setup (ollama or vllm)'
required: true required: true
default: 'ollama' default: 'ollama'
run-vision-tests: test-suite:
description: 'Whether to setup provider for vision tests' description: 'Test suite to use: base, responses, vision, etc.'
required: false required: false
default: 'false' default: ''
inference-mode: inference-mode:
description: 'Inference mode (record or replay)' description: 'Inference mode (record or replay)'
required: true required: true
@ -33,7 +33,7 @@ runs:
if: ${{ inputs.provider == 'ollama' && inputs.inference-mode == 'record' }} if: ${{ inputs.provider == 'ollama' && inputs.inference-mode == 'record' }}
uses: ./.github/actions/setup-ollama uses: ./.github/actions/setup-ollama
with: with:
run-vision-tests: ${{ inputs.run-vision-tests }} test-suite: ${{ inputs.test-suite }}
- name: Setup vllm - name: Setup vllm
if: ${{ inputs.provider == 'vllm' && inputs.inference-mode == 'record' }} if: ${{ inputs.provider == 'vllm' && inputs.inference-mode == 'record' }}

View file

@ -8,7 +8,7 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl
| Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script | | Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script |
| Integration Auth Tests | [integration-auth-tests.yml](integration-auth-tests.yml) | Run the integration test suite with Kubernetes authentication | | Integration Auth Tests | [integration-auth-tests.yml](integration-auth-tests.yml) | Run the integration test suite with Kubernetes authentication |
| SqlStore Integration Tests | [integration-sql-store-tests.yml](integration-sql-store-tests.yml) | Run the integration test suite with SqlStore | | SqlStore Integration Tests | [integration-sql-store-tests.yml](integration-sql-store-tests.yml) | Run the integration test suite with SqlStore |
| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suite from tests/integration in replay mode | | Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suites from tests/integration in replay mode |
| Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers | | Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers |
| Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks | | Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks |
| Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build | | Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build |

View file

@ -1,6 +1,6 @@
name: Integration Tests (Replay) name: Integration Tests (Replay)
run-name: Run the integration test suite from tests/integration in replay mode run-name: Run the integration test suites from tests/integration in replay mode
on: on:
push: push:
@ -32,14 +32,6 @@ on:
description: 'Test against a specific provider' description: 'Test against a specific provider'
type: string type: string
default: 'ollama' default: 'ollama'
test-subdirs:
description: 'Comma-separated list of test subdirectories to run'
type: string
default: ''
test-pattern:
description: 'Regex pattern to pass to pytest -k'
type: string
default: ''
concurrency: concurrency:
# Skip concurrency for pushes to main - each commit should be tested independently # Skip concurrency for pushes to main - each commit should be tested independently
@ -50,7 +42,7 @@ jobs:
run-replay-mode-tests: run-replay-mode-tests:
runs-on: ubuntu-latest runs-on: ubuntu-latest
name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }} name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, {4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.test-suite) }}
strategy: strategy:
fail-fast: false fail-fast: false
@ -61,7 +53,7 @@ jobs:
# Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12 # Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }} python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }} client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
run-vision-tests: [true, false] test-suite: [base, vision]
steps: steps:
- name: Checkout repository - name: Checkout repository
@ -73,15 +65,13 @@ jobs:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
client-version: ${{ matrix.client-version }} client-version: ${{ matrix.client-version }}
provider: ${{ matrix.provider }} provider: ${{ matrix.provider }}
run-vision-tests: ${{ matrix.run-vision-tests }} test-suite: ${{ matrix.test-suite }}
inference-mode: 'replay' inference-mode: 'replay'
- name: Run tests - name: Run tests
uses: ./.github/actions/run-and-record-tests uses: ./.github/actions/run-and-record-tests
with: with:
test-subdirs: ${{ inputs.test-subdirs }}
test-pattern: ${{ inputs.test-pattern }}
stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }} stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }}
provider: ${{ matrix.provider }} provider: ${{ matrix.provider }}
inference-mode: 'replay' inference-mode: 'replay'
run-vision-tests: ${{ matrix.run-vision-tests }} test-suite: ${{ matrix.test-suite }}

View file

@ -10,18 +10,18 @@ run-name: Run the integration test suite from tests/integration
on: on:
workflow_dispatch: workflow_dispatch:
inputs: inputs:
test-subdirs:
description: 'Comma-separated list of test subdirectories to run'
type: string
default: ''
test-provider: test-provider:
description: 'Test against a specific provider' description: 'Test against a specific provider'
type: string type: string
default: 'ollama' default: 'ollama'
run-vision-tests: test-suite:
description: 'Whether to run vision tests' description: 'Test suite to use: base, responses, vision, etc.'
type: boolean type: string
default: false default: ''
test-subdirs:
description: 'Comma-separated list of test subdirectories to run; overrides test-suite'
type: string
default: ''
test-pattern: test-pattern:
description: 'Regex pattern to pass to pytest -k' description: 'Regex pattern to pass to pytest -k'
type: string type: string
@ -38,11 +38,11 @@ jobs:
- name: Echo workflow inputs - name: Echo workflow inputs
run: | run: |
echo "::group::Workflow Inputs" echo "::group::Workflow Inputs"
echo "test-subdirs: ${{ inputs.test-subdirs }}"
echo "test-provider: ${{ inputs.test-provider }}"
echo "run-vision-tests: ${{ inputs.run-vision-tests }}"
echo "test-pattern: ${{ inputs.test-pattern }}"
echo "branch: ${{ github.ref_name }}" echo "branch: ${{ github.ref_name }}"
echo "test-provider: ${{ inputs.test-provider }}"
echo "test-suite: ${{ inputs.test-suite }}"
echo "test-subdirs: ${{ inputs.test-subdirs }}"
echo "test-pattern: ${{ inputs.test-pattern }}"
echo "::endgroup::" echo "::endgroup::"
- name: Checkout repository - name: Checkout repository
@ -56,15 +56,15 @@ jobs:
python-version: "3.12" # Use single Python version for recording python-version: "3.12" # Use single Python version for recording
client-version: "latest" client-version: "latest"
provider: ${{ inputs.test-provider || 'ollama' }} provider: ${{ inputs.test-provider || 'ollama' }}
run-vision-tests: ${{ inputs.run-vision-tests }} test-suite: ${{ inputs.test-suite }}
inference-mode: 'record' inference-mode: 'record'
- name: Run and record tests - name: Run and record tests
uses: ./.github/actions/run-and-record-tests uses: ./.github/actions/run-and-record-tests
with: with:
test-pattern: ${{ inputs.test-pattern }}
test-subdirs: ${{ inputs.test-subdirs }}
stack-config: 'server:ci-tests' # recording must be done with server since more tests are run stack-config: 'server:ci-tests' # recording must be done with server since more tests are run
provider: ${{ inputs.test-provider || 'ollama' }} provider: ${{ inputs.test-provider || 'ollama' }}
inference-mode: 'record' inference-mode: 'record'
run-vision-tests: ${{ inputs.run-vision-tests }} test-suite: ${{ inputs.test-suite }}
test-subdirs: ${{ inputs.test-subdirs }}
test-pattern: ${{ inputs.test-pattern }}

View file

@ -52,7 +52,6 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
provider_vector_db_id: str | None = None, provider_vector_db_id: str | None = None,
vector_db_name: str | None = None, vector_db_name: str | None = None,
) -> VectorDB: ) -> VectorDB:
provider_vector_db_id = provider_vector_db_id or vector_db_id
if provider_id is None: if provider_id is None:
if len(self.impls_by_provider_id) > 0: if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0] provider_id = list(self.impls_by_provider_id.keys())[0]
@ -69,14 +68,33 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding) raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
if "embedding_dimension" not in model.metadata: if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension") raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
provider = self.impls_by_provider_id[provider_id]
logger.warning(
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
)
vector_store = await provider.openai_create_vector_store(
name=vector_db_name or vector_db_id,
embedding_model=embedding_model,
embedding_dimension=model.metadata["embedding_dimension"],
provider_id=provider_id,
provider_vector_db_id=provider_vector_db_id,
)
vector_store_id = vector_store.id
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
logger.warning(
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
)
vector_db_data = { vector_db_data = {
"identifier": vector_db_id, "identifier": vector_store_id,
"type": ResourceType.vector_db.value, "type": ResourceType.vector_db.value,
"provider_id": provider_id, "provider_id": provider_id,
"provider_resource_id": provider_vector_db_id, "provider_resource_id": actual_provider_vector_db_id,
"embedding_model": embedding_model, "embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"], "embedding_dimension": model.metadata["embedding_dimension"],
"vector_db_name": vector_db_name, "vector_db_name": vector_store.name,
} }
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data) vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
await self.register_object(vector_db) await self.register_object(vector_db)
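
For callers, the practical effect of this change is that the identifier returned by registration is now the generated vector store id rather than the supplied `vector_db_id`. A minimal sketch of the new behavior (the `client` object and model names are illustrative; the assertions mirror the integration test updated further below):

```python
# Sketch: registering a "vector DB" now creates an OpenAI-style vector store.
# The supplied vector_db_id survives only as the store's name; the returned
# identifier is the generated store id.
response = client.vector_dbs.register(              # `client` is an assumed llama-stack client instance
    vector_db_id="my-custom-vector-db",              # becomes the vector store *name*
    embedding_model="sentence-transformers/all-MiniLM-L6-v2",
    embedding_dimension=384,
)
assert response.identifier.startswith("vs_")         # UUID-based vector store id
assert response.identifier != "my-custom-vector-db"  # the user-supplied id is no longer the identifier
```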

View file

@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches):
# TODO: set expiration time for garbage collection # TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions"]: if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
raise ValueError( raise ValueError(
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint", f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
) )
if completion_window != "24h": if completion_window != "24h":
@ -424,13 +424,21 @@ class ReferenceBatchesImpl(Batches):
) )
valid = False valid = False
for param, expected_type, type_string in [ if batch.endpoint == "/v1/chat/completions":
("model", str, "a string"), required_params = [
# messages is specific to /v1/chat/completions ("model", str, "a string"),
# we could skip validating messages here and let inference fail. however, # messages is specific to /v1/chat/completions
# that would be a very expensive way to find out messages is wrong. # we could skip validating messages here and let inference fail. however,
("messages", list, "an array"), # TODO: allow messages to be a string? # that would be a very expensive way to find out messages is wrong.
]: ("messages", list, "an array"), # TODO: allow messages to be a string?
]
else: # /v1/completions
required_params = [
("model", str, "a string"),
("prompt", str, "a string"), # TODO: allow prompt to be a list of strings??
]
for param, expected_type, type_string in required_params:
if param not in body: if param not in body:
errors.append( errors.append(
BatchError( BatchError(
@ -591,20 +599,37 @@ class ReferenceBatchesImpl(Batches):
try: try:
# TODO(SECURITY): review body for security issues # TODO(SECURITY): review body for security issues
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]] if request.url == "/v1/chat/completions":
chat_response = await self.inference_api.openai_chat_completion(**request.body) request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
chat_response = await self.inference_api.openai_chat_completion(**request.body)
# this is for mypy, we don't allow streaming so we'll get the right type # this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method" assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
return { return {
"id": request_id, "id": request_id,
"custom_id": request.custom_id, "custom_id": request.custom_id,
"response": { "response": {
"status_code": 200, "status_code": 200,
"request_id": request_id, # TODO: should this be different? "request_id": request_id, # TODO: should this be different?
"body": chat_response.model_dump_json(), "body": chat_response.model_dump_json(),
}, },
} }
else: # /v1/completions
completion_response = await self.inference_api.openai_completion(**request.body)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(completion_response, "model_dump_json"), (
"Completion response must have model_dump_json method"
)
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id,
"body": completion_response.model_dump_json(),
},
}
except Exception as e: except Exception as e:
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}") logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
return { return {
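
With `/v1/completions` now accepted alongside `/v1/chat/completions`, a batch input line only needs `model` and `prompt` in its body. A minimal sketch of such an entry (the model id is illustrative; the shape mirrors the integration test added below):

```python
# One entry of a JSONL batch input file targeting the new /v1/completions path.
completion_batch_request = {
    "custom_id": "success-1",
    "method": "POST",
    "url": "/v1/completions",
    "body": {
        "model": "llama3.2:3b-instruct-fp16",  # illustrative model id
        "prompt": "Say completions",
        "max_tokens": 20,
    },
}
```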

View file

@ -15,7 +15,7 @@ set -euo pipefail
BRANCH="" BRANCH=""
TEST_SUBDIRS="" TEST_SUBDIRS=""
TEST_PROVIDER="ollama" TEST_PROVIDER="ollama"
RUN_VISION_TESTS=false TEST_SUITE="base"
TEST_PATTERN="" TEST_PATTERN=""
# Help function # Help function
@ -27,9 +27,9 @@ Trigger the integration test recording workflow remotely. This way you do not ne
OPTIONS: OPTIONS:
-b, --branch BRANCH Branch to run the workflow on (defaults to current branch) -b, --branch BRANCH Branch to run the workflow on (defaults to current branch)
-s, --test-subdirs DIRS Comma-separated list of test subdirectories to run (REQUIRED)
-p, --test-provider PROVIDER Test provider to use: vllm or ollama (default: ollama) -p, --test-provider PROVIDER Test provider to use: vllm or ollama (default: ollama)
-v, --run-vision-tests Include vision tests in the recording -t, --test-suite SUITE Test suite to use: base, responses, vision, etc. (default: base)
-s, --test-subdirs DIRS Comma-separated list of test subdirectories to run (overrides suite)
-k, --test-pattern PATTERN Regex pattern to pass to pytest -k -k, --test-pattern PATTERN Regex pattern to pass to pytest -k
-h, --help Show this help message -h, --help Show this help message
@ -38,7 +38,7 @@ EXAMPLES:
$0 --test-subdirs "agents" $0 --test-subdirs "agents"
# Record tests for specific branch with vision tests # Record tests for specific branch with vision tests
$0 -b my-feature-branch --test-subdirs "inference" --run-vision-tests $0 -b my-feature-branch --test-suite vision
# Record multiple test subdirectories with specific provider # Record multiple test subdirectories with specific provider
$0 --test-subdirs "agents,inference" --test-provider vllm $0 --test-subdirs "agents,inference" --test-provider vllm
@ -71,9 +71,9 @@ while [[ $# -gt 0 ]]; do
TEST_PROVIDER="$2" TEST_PROVIDER="$2"
shift 2 shift 2
;; ;;
-v|--run-vision-tests) -t|--test-suite)
RUN_VISION_TESTS=true TEST_SUITE="$2"
shift shift 2
;; ;;
-k|--test-pattern) -k|--test-pattern)
TEST_PATTERN="$2" TEST_PATTERN="$2"
@ -92,11 +92,11 @@ while [[ $# -gt 0 ]]; do
done done
# Validate required parameters # Validate required parameters
if [[ -z "$TEST_SUBDIRS" ]]; then if [[ -z "$TEST_SUBDIRS" && -z "$TEST_SUITE" ]]; then
echo "Error: --test-subdirs is required" echo "Error: --test-subdirs or --test-suite is required"
echo "Please specify which test subdirectories to run, e.g.:" echo "Please specify which test subdirectories to run or test suite to use, e.g.:"
echo " $0 --test-subdirs \"agents,inference\"" echo " $0 --test-subdirs \"agents,inference\""
echo " $0 --test-subdirs \"inference\" --run-vision-tests" echo " $0 --test-suite vision"
echo "" echo ""
exit 1 exit 1
fi fi
@ -239,17 +239,19 @@ echo "Triggering integration test recording workflow..."
echo "Branch: $BRANCH" echo "Branch: $BRANCH"
echo "Test provider: $TEST_PROVIDER" echo "Test provider: $TEST_PROVIDER"
echo "Test subdirs: $TEST_SUBDIRS" echo "Test subdirs: $TEST_SUBDIRS"
echo "Run vision tests: $RUN_VISION_TESTS" echo "Test suite: $TEST_SUITE"
echo "Test pattern: ${TEST_PATTERN:-"(none)"}" echo "Test pattern: ${TEST_PATTERN:-"(none)"}"
echo "" echo ""
# Prepare inputs for gh workflow run # Prepare inputs for gh workflow run
INPUTS="-f test-subdirs='$TEST_SUBDIRS'" if [[ -n "$TEST_SUBDIRS" ]]; then
INPUTS="-f test-subdirs='$TEST_SUBDIRS'"
fi
if [[ -n "$TEST_PROVIDER" ]]; then if [[ -n "$TEST_PROVIDER" ]]; then
INPUTS="$INPUTS -f test-provider='$TEST_PROVIDER'" INPUTS="$INPUTS -f test-provider='$TEST_PROVIDER'"
fi fi
if [[ "$RUN_VISION_TESTS" == "true" ]]; then if [[ -n "$TEST_SUITE" ]]; then
INPUTS="$INPUTS -f run-vision-tests=true" INPUTS="$INPUTS -f test-suite='$TEST_SUITE'"
fi fi
if [[ -n "$TEST_PATTERN" ]]; then if [[ -n "$TEST_PATTERN" ]]; then
INPUTS="$INPUTS -f test-pattern='$TEST_PATTERN'" INPUTS="$INPUTS -f test-pattern='$TEST_PATTERN'"

View file

@ -16,7 +16,7 @@ STACK_CONFIG=""
PROVIDER="" PROVIDER=""
TEST_SUBDIRS="" TEST_SUBDIRS=""
TEST_PATTERN="" TEST_PATTERN=""
RUN_VISION_TESTS="false" TEST_SUITE="base"
INFERENCE_MODE="replay" INFERENCE_MODE="replay"
EXTRA_PARAMS="" EXTRA_PARAMS=""
@ -28,12 +28,16 @@ Usage: $0 [OPTIONS]
Options: Options:
--stack-config STRING Stack configuration to use (required) --stack-config STRING Stack configuration to use (required)
--provider STRING Provider to use (ollama, vllm, etc.) (required) --provider STRING Provider to use (ollama, vllm, etc.) (required)
--test-subdirs STRING Comma-separated list of test subdirectories to run (default: 'inference') --test-suite STRING Comma-separated list of test suites to run (default: 'base')
--run-vision-tests Run vision tests instead of regular tests
--inference-mode STRING Inference mode: record or replay (default: replay) --inference-mode STRING Inference mode: record or replay (default: replay)
--test-subdirs STRING Comma-separated list of test subdirectories to run (overrides suite)
--test-pattern STRING Regex pattern to pass to pytest -k --test-pattern STRING Regex pattern to pass to pytest -k
--help Show this help message --help Show this help message
Suites are defined in tests/integration/suites.py. They are used to narrow the collection of tests and provide default model options.
You can also specify subdirectories (of tests/integration) to select tests from, which will override the suite.
Examples: Examples:
# Basic inference tests with ollama # Basic inference tests with ollama
$0 --stack-config server:ci-tests --provider ollama $0 --stack-config server:ci-tests --provider ollama
@ -42,7 +46,7 @@ Examples:
$0 --stack-config server:ci-tests --provider vllm --test-subdirs 'inference,agents' $0 --stack-config server:ci-tests --provider vllm --test-subdirs 'inference,agents'
# Vision tests with ollama # Vision tests with ollama
$0 --stack-config server:ci-tests --provider ollama --run-vision-tests $0 --stack-config server:ci-tests --provider ollama --test-suite vision
# Record mode for updating test recordings # Record mode for updating test recordings
$0 --stack-config server:ci-tests --provider ollama --inference-mode record $0 --stack-config server:ci-tests --provider ollama --inference-mode record
@ -64,9 +68,9 @@ while [[ $# -gt 0 ]]; do
TEST_SUBDIRS="$2" TEST_SUBDIRS="$2"
shift 2 shift 2
;; ;;
--run-vision-tests) --test-suite)
RUN_VISION_TESTS="true" TEST_SUITE="$2"
shift shift 2
;; ;;
--inference-mode) --inference-mode)
INFERENCE_MODE="$2" INFERENCE_MODE="$2"
@ -92,22 +96,25 @@ done
# Validate required parameters # Validate required parameters
if [[ -z "$STACK_CONFIG" ]]; then if [[ -z "$STACK_CONFIG" ]]; then
echo "Error: --stack-config is required" echo "Error: --stack-config is required"
usage
exit 1 exit 1
fi fi
if [[ -z "$PROVIDER" ]]; then if [[ -z "$PROVIDER" ]]; then
echo "Error: --provider is required" echo "Error: --provider is required"
usage exit 1
fi
if [[ -z "$TEST_SUITE" && -z "$TEST_SUBDIRS" ]]; then
echo "Error: --test-suite or --test-subdirs is required"
exit 1 exit 1
fi fi
echo "=== Llama Stack Integration Test Runner ===" echo "=== Llama Stack Integration Test Runner ==="
echo "Stack Config: $STACK_CONFIG" echo "Stack Config: $STACK_CONFIG"
echo "Provider: $PROVIDER" echo "Provider: $PROVIDER"
echo "Test Subdirs: $TEST_SUBDIRS"
echo "Vision Tests: $RUN_VISION_TESTS"
echo "Inference Mode: $INFERENCE_MODE" echo "Inference Mode: $INFERENCE_MODE"
echo "Test Suite: $TEST_SUITE"
echo "Test Subdirs: $TEST_SUBDIRS"
echo "Test Pattern: $TEST_PATTERN" echo "Test Pattern: $TEST_PATTERN"
echo "" echo ""
@ -194,84 +201,46 @@ if [[ -n "$TEST_PATTERN" ]]; then
PYTEST_PATTERN="${PYTEST_PATTERN} and $TEST_PATTERN" PYTEST_PATTERN="${PYTEST_PATTERN} and $TEST_PATTERN"
fi fi
# Run vision tests if specified
if [[ "$RUN_VISION_TESTS" == "true" ]]; then
echo "Running vision tests..."
set +e
pytest -s -v tests/integration/inference/test_vision_inference.py \
--stack-config="$STACK_CONFIG" \
-k "$PYTEST_PATTERN" \
--vision-model=ollama/llama3.2-vision:11b \
--embedding-model=sentence-transformers/all-MiniLM-L6-v2 \
--color=yes $EXTRA_PARAMS \
--capture=tee-sys
exit_code=$?
set -e
if [ $exit_code -eq 0 ]; then
echo "✅ Vision tests completed successfully"
elif [ $exit_code -eq 5 ]; then
echo "⚠️ No vision tests collected (pattern matched no tests)"
else
echo "❌ Vision tests failed"
exit 1
fi
exit 0
fi
# Run regular tests
if [[ -z "$TEST_SUBDIRS" ]]; then
TEST_SUBDIRS=$(find tests/integration -maxdepth 1 -mindepth 1 -type d |
sed 's|tests/integration/||' |
grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" |
sort)
fi
echo "Test subdirs to run: $TEST_SUBDIRS" echo "Test subdirs to run: $TEST_SUBDIRS"
# Collect all test files for the specified test types if [[ -n "$TEST_SUBDIRS" ]]; then
TEST_FILES="" # Collect all test files for the specified test types
for test_subdir in $(echo "$TEST_SUBDIRS" | tr ',' '\n'); do TEST_FILES=""
# Skip certain test types for vllm provider for test_subdir in $(echo "$TEST_SUBDIRS" | tr ',' '\n'); do
if [[ "$PROVIDER" == "vllm" ]]; then if [[ -d "tests/integration/$test_subdir" ]]; then
if [[ "$test_subdir" == "safety" ]] || [[ "$test_subdir" == "post_training" ]] || [[ "$test_subdir" == "tool_runtime" ]]; then # Find all Python test files in this directory
echo "Skipping $test_subdir for vllm provider" test_files=$(find tests/integration/$test_subdir -name "test_*.py" -o -name "*_test.py")
continue if [[ -n "$test_files" ]]; then
TEST_FILES="$TEST_FILES $test_files"
echo "Added test files from $test_subdir: $(echo $test_files | wc -w) files"
fi
else
echo "Warning: Directory tests/integration/$test_subdir does not exist"
fi fi
done
if [[ -z "$TEST_FILES" ]]; then
echo "No test files found for the specified test types"
exit 1
fi fi
if [[ "$STACK_CONFIG" != *"server:"* ]] && [[ "$test_subdir" == "batches" ]]; then echo ""
echo "Skipping $test_subdir for library client until types are supported" echo "=== Running all collected tests in a single pytest command ==="
continue echo "Total test files: $(echo $TEST_FILES | wc -w)"
fi
if [[ -d "tests/integration/$test_subdir" ]]; then PYTEST_TARGET="$TEST_FILES"
# Find all Python test files in this directory EXTRA_PARAMS="$EXTRA_PARAMS --text-model=$TEXT_MODEL --embedding-model=sentence-transformers/all-MiniLM-L6-v2"
test_files=$(find tests/integration/$test_subdir -name "test_*.py" -o -name "*_test.py") else
if [[ -n "$test_files" ]]; then PYTEST_TARGET="tests/integration/"
TEST_FILES="$TEST_FILES $test_files" EXTRA_PARAMS="$EXTRA_PARAMS --suite=$TEST_SUITE"
echo "Added test files from $test_subdir: $(echo $test_files | wc -w) files"
fi
else
echo "Warning: Directory tests/integration/$test_subdir does not exist"
fi
done
if [[ -z "$TEST_FILES" ]]; then
echo "No test files found for the specified test types"
exit 1
fi fi
echo ""
echo "=== Running all collected tests in a single pytest command ==="
echo "Total test files: $(echo $TEST_FILES | wc -w)"
set +e set +e
pytest -s -v $TEST_FILES \ pytest -s -v $PYTEST_TARGET \
--stack-config="$STACK_CONFIG" \ --stack-config="$STACK_CONFIG" \
-k "$PYTEST_PATTERN" \ -k "$PYTEST_PATTERN" \
--text-model="$TEXT_MODEL" \ $EXTRA_PARAMS \
--embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ --color=yes \
--color=yes $EXTRA_PARAMS \
--capture=tee-sys --capture=tee-sys
exit_code=$? exit_code=$?
set -e set -e
@ -294,7 +263,13 @@ df -h
# stop server # stop server
if [[ "$STACK_CONFIG" == *"server:"* ]]; then if [[ "$STACK_CONFIG" == *"server:"* ]]; then
echo "Stopping Llama Stack Server..." echo "Stopping Llama Stack Server..."
kill $(lsof -i :8321 | awk 'NR>1 {print $2}') pids=$(lsof -i :8321 | awk 'NR>1 {print $2}')
if [[ -n "$pids" ]]; then
echo "Killing Llama Stack Server processes: $pids"
kill -9 $pids
else
echo "No Llama Stack Server processes found ?!"
fi
echo "Llama Stack Server stopped" echo "Llama Stack Server stopped"
fi fi

View file

@ -77,7 +77,7 @@ You must be careful when re-recording. CI workflows assume a specific setup for
./scripts/github/schedule-record-workflow.sh --test-subdirs "agents,inference" ./scripts/github/schedule-record-workflow.sh --test-subdirs "agents,inference"
# Record with vision tests enabled # Record with vision tests enabled
./scripts/github/schedule-record-workflow.sh --test-subdirs "inference" --run-vision-tests ./scripts/github/schedule-record-workflow.sh --test-suite vision
# Record with specific provider # Record with specific provider
./scripts/github/schedule-record-workflow.sh --test-subdirs "agents" --test-provider vllm ./scripts/github/schedule-record-workflow.sh --test-subdirs "agents" --test-provider vllm

View file

@ -42,6 +42,27 @@ Model parameters can be influenced by the following options:
Each of these are comma-separated lists and can be used to generate multiple parameter combinations. Note that tests will be skipped Each of these are comma-separated lists and can be used to generate multiple parameter combinations. Note that tests will be skipped
if no model is specified. if no model is specified.
### Suites (fast selection + sane defaults)
- `--suite`: comma-separated list of named suites that both narrow which tests are collected and prefill common model options (unless you pass them explicitly).
- Available suites:
- `responses`: collects tests under `tests/integration/responses`; this is a separate suite because it needs a strong tool-calling model.
- `vision`: collects only `tests/integration/inference/test_vision_inference.py`; defaults `--vision-model=ollama/llama3.2-vision:11b`, `--embedding-model=sentence-transformers/all-MiniLM-L6-v2`.
- Explicit flags always win. For example, `--suite=responses --text-model=<X>` overrides the suite's text model.
Examples:
```bash
# Fast responses run with defaults
pytest -s -v tests/integration --stack-config=server:starter --suite=responses
# Fast single-file vision run with defaults
pytest -s -v tests/integration --stack-config=server:starter --suite=vision
# Combine suites and override a default
pytest -s -v tests/integration --stack-config=server:starter --suite=responses,vision --embedding-model=text-embedding-3-small
```
## Examples ## Examples
### Testing against a Server ### Testing against a Server

View file

@ -268,3 +268,58 @@ class TestBatchesIntegration:
deleted_error_file = openai_client.files.delete(final_batch.error_file_id) deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully" assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully"
def test_batch_e2e_completions(self, openai_client, batch_helper, text_model_id):
"""Run an end-to-end batch with a single successful text completion request."""
request_body = {"model": text_model_id, "prompt": "Say completions", "max_tokens": 20}
batch_requests = [
{
"custom_id": "success-1",
"method": "POST",
"url": "/v1/completions",
"body": request_body,
}
]
with batch_helper.create_file(batch_requests) as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/completions",
completion_window="24h",
metadata={"test": "e2e_completions_success"},
)
final_batch = batch_helper.wait_for(
batch.id,
max_wait_time=3 * 60,
expected_statuses={"completed"},
timeout_action="skip",
)
assert final_batch.status == "completed"
assert final_batch.request_counts is not None
assert final_batch.request_counts.total == 1
assert final_batch.request_counts.completed == 1
assert final_batch.output_file_id is not None
output_content = openai_client.files.content(final_batch.output_file_id)
if isinstance(output_content, str):
output_text = output_content
else:
output_text = output_content.content.decode("utf-8")
output_lines = output_text.strip().split("\n")
assert len(output_lines) == 1
result = json.loads(output_lines[0])
assert result["custom_id"] == "success-1"
assert "response" in result
assert result["response"]["status_code"] == 200
deleted_output_file = openai_client.files.delete(final_batch.output_file_id)
assert deleted_output_file.deleted
if final_batch.error_file_id is not None:
deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
assert deleted_error_file.deleted

View file

@ -6,15 +6,17 @@
import inspect import inspect
import itertools import itertools
import os import os
import platform
import textwrap import textwrap
import time import time
from pathlib import Path
import pytest import pytest
from dotenv import load_dotenv from dotenv import load_dotenv
from llama_stack.log import get_logger from llama_stack.log import get_logger
from .suites import SUITE_DEFINITIONS
logger = get_logger(__name__, category="tests") logger = get_logger(__name__, category="tests")
@ -61,9 +63,22 @@ def pytest_configure(config):
key, value = env_var.split("=", 1) key, value = env_var.split("=", 1)
os.environ[key] = value os.environ[key] = value
if platform.system() == "Darwin": # Darwin is the system name for macOS suites_raw = config.getoption("--suite")
os.environ["DISABLE_CODE_SANDBOX"] = "1" suites: list[str] = []
logger.info("Setting DISABLE_CODE_SANDBOX=1 for macOS") if suites_raw:
suites = [p.strip() for p in str(suites_raw).split(",") if p.strip()]
unknown = [p for p in suites if p not in SUITE_DEFINITIONS]
if unknown:
raise pytest.UsageError(
f"Unknown suite(s): {', '.join(unknown)}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}"
)
for suite in suites:
suite_def = SUITE_DEFINITIONS.get(suite, {})
defaults: dict = suite_def.get("defaults", {})
for dest, value in defaults.items():
current = getattr(config.option, dest, None)
if not current:
setattr(config.option, dest, value)
def pytest_addoption(parser): def pytest_addoption(parser):
@ -105,16 +120,21 @@ def pytest_addoption(parser):
default=384, default=384,
help="Output dimensionality of the embedding model to use for testing. Default: 384", help="Output dimensionality of the embedding model to use for testing. Default: 384",
) )
parser.addoption(
"--record-responses",
action="store_true",
help="Record new API responses instead of using cached ones.",
)
parser.addoption( parser.addoption(
"--report", "--report",
help="Path where the test report should be written, e.g. --report=/path/to/report.md", help="Path where the test report should be written, e.g. --report=/path/to/report.md",
) )
available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys()))
suite_help = (
"Comma-separated integration test suites to narrow collection and prefill defaults. "
"Available: "
f"{available_suites}. "
"Explicit CLI flags (e.g., --text-model) override suite defaults. "
"Examples: --suite=responses or --suite=responses,vision."
)
parser.addoption("--suite", help=suite_help)
MODEL_SHORT_IDS = { MODEL_SHORT_IDS = {
"meta-llama/Llama-3.2-3B-Instruct": "3B", "meta-llama/Llama-3.2-3B-Instruct": "3B",
@ -197,3 +217,40 @@ def pytest_generate_tests(metafunc):
pytest_plugins = ["tests.integration.fixtures.common"] pytest_plugins = ["tests.integration.fixtures.common"]
def pytest_ignore_collect(path: str, config: pytest.Config) -> bool:
"""Skip collecting paths outside the selected suite roots for speed."""
suites_raw = config.getoption("--suite")
if not suites_raw:
return False
names = [p.strip() for p in str(suites_raw).split(",") if p.strip()]
roots: list[str] = []
for name in names:
suite_def = SUITE_DEFINITIONS.get(name)
if suite_def:
roots.extend(suite_def.get("roots", []))
if not roots:
return False
p = Path(str(path)).resolve()
# Only constrain within tests/integration to avoid ignoring unrelated tests
integration_root = (Path(str(config.rootpath)) / "tests" / "integration").resolve()
if not p.is_relative_to(integration_root):
return False
for r in roots:
rp = (Path(str(config.rootpath)) / r).resolve()
if rp.is_file():
# Allow the exact file and any ancestor directories so pytest can walk into it.
if p == rp:
return False
if p.is_dir() and rp.is_relative_to(p):
return False
else:
# Allow anything inside an allowed directory
if p.is_relative_to(rp):
return False
return True

View file

@ -58,6 +58,15 @@ def skip_if_model_doesnt_support_suffix(client_with_models, model_id):
pytest.skip(f"Provider {provider.provider_type} doesn't support suffix.") pytest.skip(f"Provider {provider.provider_type} doesn't support suffix.")
def skip_if_doesnt_support_n(client_with_models, model_id):
provider = provider_from_model(client_with_models, model_id)
if provider.provider_type in (
"remote::sambanova",
"remote::ollama",
):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")
def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id): def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
provider = provider_from_model(client_with_models, model_id) provider = provider_from_model(client_with_models, model_id)
if provider.provider_type in ( if provider.provider_type in (
@ -262,10 +271,7 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
) )
def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case): def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
skip_if_doesnt_support_n(client_with_models, text_model_id)
provider = provider_from_model(client_with_models, text_model_id)
if provider.provider_type == "remote::ollama":
pytest.skip(f"Model {text_model_id} hosted by {provider.provider_type} doesn't support n > 1.")
tc = TestCase(test_case) tc = TestCase(test_case)
question = tc["question"] question = tc["question"]

View file

@ -0,0 +1,42 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"prompt": "Say completions",
"max_tokens": 20
},
"endpoint": "/v1/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": {
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-271",
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": "You want me to respond with a completion, but you didn't specify what I should complete. Could"
}
],
"created": 1756846620,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": {
"completion_tokens": 20,
"prompt_tokens": 28,
"total_tokens": 48,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
"is_streaming": false
}
}

View file

Binary image file: 108 KiB before, 108 KiB after.

View file

Binary image file: 148 KiB before, 148 KiB after.

View file

Binary image file: 139 KiB before, 139 KiB after.

View file

@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Central definition of integration test suites. You can use these suites by passing --suite=name to pytest.
# For example:
#
# ```bash
# pytest tests/integration/ --suite=vision
# ```
#
# Each suite can:
# - restrict collection to specific roots (dirs or files)
# - provide default CLI option values (e.g. text_model, embedding_model, etc.)
from pathlib import Path
this_dir = Path(__file__).parent
default_roots = [
str(p)
for p in this_dir.glob("*")
if p.is_dir()
and p.name not in ("__pycache__", "fixtures", "test_cases", "recordings", "responses", "post_training")
]
SUITE_DEFINITIONS: dict[str, dict] = {
"base": {
"description": "Base suite that includes most tests but runs them with a text Ollama model",
"roots": default_roots,
"defaults": {
"text_model": "ollama/llama3.2:3b-instruct-fp16",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
},
},
"responses": {
"description": "Suite that includes only the OpenAI Responses tests; needs a strong tool-calling model",
"roots": ["tests/integration/responses"],
"defaults": {
"text_model": "openai/gpt-4o",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
},
},
"vision": {
"description": "Suite that includes only the vision tests",
"roots": ["tests/integration/inference/test_vision_inference.py"],
"defaults": {
"vision_model": "ollama/llama3.2-vision:11b",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
},
},
}
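
As a usage note, new suites follow the same shape. A hypothetical additional entry (the `safety` name, root, and defaults below are illustrative and not part of this commit) would look like:

```python
# Hypothetical example of an extra suite entry; not part of this change.
SUITE_DEFINITIONS["safety"] = {
    "description": "Suite limited to the safety tests",
    "roots": ["tests/integration/safety"],
    "defaults": {
        "text_model": "ollama/llama3.2:3b-instruct-fp16",
        "embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
    },
}
```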

View file

@ -47,34 +47,45 @@ def client_with_empty_registry(client_with_models):
def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension): def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension):
# Register a memory bank first vector_db_name = "test_vector_db"
vector_db_id = "test_vector_db" register_response = client_with_empty_registry.vector_dbs.register(
client_with_empty_registry.vector_dbs.register( vector_db_id=vector_db_name,
vector_db_id=vector_db_id,
embedding_model=embedding_model_id, embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
) )
actual_vector_db_id = register_response.identifier
# Retrieve the memory bank and validate its properties # Retrieve the memory bank and validate its properties
response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=vector_db_id) response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=actual_vector_db_id)
assert response is not None assert response is not None
assert response.identifier == vector_db_id assert response.identifier == actual_vector_db_id
assert response.embedding_model == embedding_model_id assert response.embedding_model == embedding_model_id
assert response.provider_resource_id == vector_db_id assert response.identifier.startswith("vs_")
def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension): def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension):
vector_db_id = "test_vector_db" vector_db_name = "test_vector_db"
client_with_empty_registry.vector_dbs.register( response = client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id, vector_db_id=vector_db_name,
embedding_model=embedding_model_id, embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
) )
vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] actual_vector_db_id = response.identifier
assert vector_dbs_after_register == [vector_db_id] assert actual_vector_db_id.startswith("vs_")
assert actual_vector_db_id != vector_db_name
client_with_empty_registry.vector_dbs.unregister(vector_db_id=vector_db_id) vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
assert vector_dbs_after_register == [actual_vector_db_id]
vector_stores = client_with_empty_registry.vector_stores.list()
assert len(vector_stores.data) == 1
vector_store = vector_stores.data[0]
assert vector_store.id == actual_vector_db_id
assert vector_store.name == vector_db_name
client_with_empty_registry.vector_dbs.unregister(vector_db_id=actual_vector_db_id)
vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
assert len(vector_dbs) == 0 assert len(vector_dbs) == 0
@ -91,20 +102,22 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe
], ],
) )
def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case): def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case):
vector_db_id = "test_vector_db" vector_db_name = "test_vector_db"
client_with_empty_registry.vector_dbs.register( register_response = client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id, vector_db_id=vector_db_name,
embedding_model=embedding_model_id, embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
) )
actual_vector_db_id = register_response.identifier
client_with_empty_registry.vector_io.insert( client_with_empty_registry.vector_io.insert(
vector_db_id=vector_db_id, vector_db_id=actual_vector_db_id,
chunks=sample_chunks, chunks=sample_chunks,
) )
response = client_with_empty_registry.vector_io.query( response = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id, vector_db_id=actual_vector_db_id,
query="What is the capital of France?", query="What is the capital of France?",
) )
assert response is not None assert response is not None
@ -113,7 +126,7 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding
query, expected_doc_id = test_case query, expected_doc_id = test_case
response = client_with_empty_registry.vector_io.query( response = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id, vector_db_id=actual_vector_db_id,
query=query, query=query,
) )
assert response is not None assert response is not None
@ -128,13 +141,15 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
"remote::qdrant": {"score_threshold": -1.0}, "remote::qdrant": {"score_threshold": -1.0},
"inline::qdrant": {"score_threshold": -1.0}, "inline::qdrant": {"score_threshold": -1.0},
} }
vector_db_id = "test_precomputed_embeddings_db" vector_db_name = "test_precomputed_embeddings_db"
client_with_empty_registry.vector_dbs.register( register_response = client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id, vector_db_id=vector_db_name,
embedding_model=embedding_model_id, embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
) )
actual_vector_db_id = register_response.identifier
chunks_with_embeddings = [ chunks_with_embeddings = [
Chunk( Chunk(
content="This is a test chunk with precomputed embedding.", content="This is a test chunk with precomputed embedding.",
@ -144,13 +159,13 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
] ]
client_with_empty_registry.vector_io.insert( client_with_empty_registry.vector_io.insert(
vector_db_id=vector_db_id, vector_db_id=actual_vector_db_id,
chunks=chunks_with_embeddings, chunks=chunks_with_embeddings,
) )
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0] provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
response = client_with_empty_registry.vector_io.query( response = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id, vector_db_id=actual_vector_db_id,
query="precomputed embedding test", query="precomputed embedding test",
params=vector_io_provider_params_dict.get(provider, None), params=vector_io_provider_params_dict.get(provider, None),
) )
@ -173,13 +188,15 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
"remote::qdrant": {"score_threshold": 0.0}, "remote::qdrant": {"score_threshold": 0.0},
"inline::qdrant": {"score_threshold": 0.0}, "inline::qdrant": {"score_threshold": 0.0},
} }
vector_db_id = "test_precomputed_embeddings_db" vector_db_name = "test_precomputed_embeddings_db"
client_with_empty_registry.vector_dbs.register( register_response = client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id, vector_db_id=vector_db_name,
embedding_model=embedding_model_id, embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
) )
actual_vector_db_id = register_response.identifier
chunks_with_embeddings = [ chunks_with_embeddings = [
Chunk( Chunk(
content="duplicate", content="duplicate",
@ -189,13 +206,13 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
] ]
client_with_empty_registry.vector_io.insert( client_with_empty_registry.vector_io.insert(
vector_db_id=vector_db_id, vector_db_id=actual_vector_db_id,
chunks=chunks_with_embeddings, chunks=chunks_with_embeddings,
) )
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0] provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
response = client_with_empty_registry.vector_io.query( response = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id, vector_db_id=actual_vector_db_id,
query="duplicate", query="duplicate",
params=vector_io_provider_params_dict.get(provider, None), params=vector_io_provider_params_dict.get(provider, None),
) )

View file

@ -146,6 +146,20 @@ class VectorDBImpl(Impl):
async def unregister_vector_db(self, vector_db_id: str): async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id return vector_db_id
async def openai_create_vector_store(self, **kwargs):
import time
import uuid
from llama_stack.apis.vector_io.vector_io import VectorStoreFileCounts, VectorStoreObject
vector_store_id = kwargs.get("provider_vector_db_id") or f"vs_{uuid.uuid4()}"
return VectorStoreObject(
id=vector_store_id,
name=kwargs.get("name", vector_store_id),
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def test_models_routing_table(cached_disk_dist_registry): async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
@ -247,17 +261,21 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
) )
# Register multiple vector databases and verify listing # Register multiple vector databases and verify listing
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model") vdb1 = await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model") vdb2 = await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
vector_dbs = await table.list_vector_dbs() vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2 assert len(vector_dbs.data) == 2
vector_db_ids = {v.identifier for v in vector_dbs.data} vector_db_ids = {v.identifier for v in vector_dbs.data}
assert "test-vectordb" in vector_db_ids assert vdb1.identifier in vector_db_ids
assert "test-vectordb-2" in vector_db_ids assert vdb2.identifier in vector_db_ids
await table.unregister_vector_db(vector_db_id="test-vectordb") # Verify they have UUID-based identifiers
await table.unregister_vector_db(vector_db_id="test-vectordb-2") assert vdb1.identifier.startswith("vs_")
assert vdb2.identifier.startswith("vs_")
await table.unregister_vector_db(vector_db_id=vdb1.identifier)
await table.unregister_vector_db(vector_db_id=vdb2.identifier)
vector_dbs = await table.list_vector_dbs() vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0 assert len(vector_dbs.data) == 0

View file

@ -7,6 +7,7 @@
# Unit tests for the routing tables vector_dbs # Unit tests for the routing tables vector_dbs
import time import time
import uuid
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
import pytest import pytest
@ -34,6 +35,7 @@ from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceI
class VectorDBImpl(Impl): class VectorDBImpl(Impl):
def __init__(self): def __init__(self):
super().__init__(Api.vector_io) super().__init__(Api.vector_io)
self.vector_stores = {}
async def register_vector_db(self, vector_db: VectorDB): async def register_vector_db(self, vector_db: VectorDB):
return vector_db return vector_db
@ -114,8 +116,35 @@ class VectorDBImpl(Impl):
async def openai_delete_vector_store_file(self, vector_store_id, file_id): async def openai_delete_vector_store_file(self, vector_store_id, file_id):
return VectorStoreFileDeleteResponse(id=file_id, deleted=True) return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
async def openai_create_vector_store(
self,
name=None,
embedding_model=None,
embedding_dimension=None,
provider_id=None,
provider_vector_db_id=None,
**kwargs,
):
vector_store_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
vector_store = VectorStoreObject(
id=vector_store_id,
name=name or vector_store_id,
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
self.vector_stores[vector_store_id] = vector_store
return vector_store
async def openai_list_vector_stores(self, **kwargs):
from llama_stack.apis.vector_io.vector_io import VectorStoreListResponse
return VectorStoreListResponse(
data=list(self.vector_stores.values()), has_more=False, first_id=None, last_id=None
)
async def test_vectordbs_routing_table(cached_disk_dist_registry): async def test_vectordbs_routing_table(cached_disk_dist_registry):
n = 10
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
@ -129,22 +158,98 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
) )
# Register multiple vector databases and verify listing # Register multiple vector databases and verify listing
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model") vdb_dict = {}
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model") for i in range(n):
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs() vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2 assert len(vector_dbs.data) == len(vdb_dict)
vector_db_ids = {v.identifier for v in vector_dbs.data} vector_db_ids = {v.identifier for v in vector_dbs.data}
assert "test-vectordb" in vector_db_ids for k in vdb_dict:
assert "test-vectordb-2" in vector_db_ids assert vdb_dict[k].identifier in vector_db_ids
for k in vdb_dict:
await table.unregister_vector_db(vector_db_id="test-vectordb") await table.unregister_vector_db(vector_db_id=vdb_dict[k].identifier)
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
vector_dbs = await table.list_vector_dbs() vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0 assert len(vector_dbs.data) == 0
async def test_vector_db_and_vector_store_id_mapping(cached_disk_dist_registry):
n = 10
impl = VectorDBImpl()
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
vdb_dict = {}
for i in range(n):
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs()
vector_db_ids = {v.identifier for v in vector_dbs.data}
vector_stores = await impl.openai_list_vector_stores()
vector_store_ids = {v.id for v in vector_stores.data}
assert vector_db_ids == vector_store_ids, (
f"Vector DB IDs {vector_db_ids} don't match vector store IDs {vector_store_ids}"
)
for vector_store in vector_stores.data:
vector_db = await table.get_vector_db(vector_store.id)
assert vector_store.name == vector_db.vector_db_name, (
f"Vector store name {vector_store.name} doesn't match vector store ID {vector_store.id}"
)
for vector_db_id in vector_db_ids:
await table.unregister_vector_db(vector_db_id)
assert len((await table.list_vector_dbs()).data) == 0
async def test_vector_db_id_becomes_vector_store_name(cached_disk_dist_registry):
impl = VectorDBImpl()
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
user_provided_id = "my-custom-vector-db"
await table.register_vector_db(vector_db_id=user_provided_id, embedding_model="test-model")
vector_stores = await impl.openai_list_vector_stores()
assert len(vector_stores.data) == 1
vector_store = vector_stores.data[0]
assert vector_store.name == user_provided_id
assert vector_store.id.startswith("vs_")
assert vector_store.id != user_provided_id
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 1
assert vector_dbs.data[0].identifier == vector_store.id
await table.unregister_vector_db(vector_store.id)
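The two tests above pin down the same contract from opposite ends: the user-supplied vector_db_id survives only as the store's display name, while the canonical identifier is a generated vs_-prefixed UUID. A simplified, self-contained illustration of that contract — a toy registry, not the real VectorDBsRoutingTable — could be:

```python
# Toy registry illustrating the id/name split the tests verify; this is only a
# sketch of the observable behavior, not the actual routing-table implementation.
import uuid


class ToyVectorDBRegistry:
    def __init__(self):
        self._by_identifier: dict[str, dict] = {}

    def register(self, requested_id: str) -> dict:
        # The requested id becomes the human-readable name; the identifier is generated.
        identifier = f"vs_{uuid.uuid4()}"
        record = {"identifier": identifier, "vector_db_name": requested_id}
        self._by_identifier[identifier] = record
        return record

    def unregister(self, identifier: str) -> None:
        self._by_identifier.pop(identifier, None)


registry = ToyVectorDBRegistry()
record = registry.register("my-custom-vector-db")
assert record["identifier"].startswith("vs_")
assert record["identifier"] != "my-custom-vector-db"
assert record["vector_db_name"] == "my-custom-vector-db"
registry.unregister(record["identifier"])
```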
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry): async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
impl = VectorDBImpl() impl = VectorDBImpl()
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK") impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
@ -164,7 +269,8 @@ async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]}) authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
with request_provider_data_context({}, authorized_user): with request_provider_data_context({}, authorized_user):
_ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model") registered_vdb = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
authorized_table = registered_vdb.identifier # Use the actual generated ID
# Authorized reader # Authorized reader
with request_provider_data_context({}, authorized_user): with request_provider_data_context({}, authorized_user):
@ -227,7 +333,8 @@ async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_registry):
) )
with request_provider_data_context({}, admin_user): with request_provider_data_context({}, admin_user):
await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model") registered_vdb = await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
vector_db_id = registered_vdb.identifier # Use the actual generated ID
read_methods = [ read_methods = [
(table.openai_retrieve_vector_store, (vector_db_id,), {}), (table.openai_retrieve_vector_store, (vector_db_id,), {}),

View file

@ -46,7 +46,8 @@ The tests are categorized and outlined below, keep this updated:
* test_validate_input_url_mismatch (negative) * test_validate_input_url_mismatch (negative)
* test_validate_input_multiple_errors_per_request (negative) * test_validate_input_multiple_errors_per_request (negative)
* test_validate_input_invalid_request_format (negative) * test_validate_input_invalid_request_format (negative)
* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation) * test_validate_input_missing_parameters_chat_completions (parametrized negative - custom_id, method, url, body, model, messages missing validation for chat/completions)
* test_validate_input_missing_parameters_completions (parametrized negative - custom_id, method, url, body, model, prompt missing validation for completions)
* test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation) * test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
The tests use temporary SQLite databases for isolation and mock external The tests use temporary SQLite databases for isolation and mock external
@ -213,7 +214,6 @@ class TestReferenceBatchesImpl:
"endpoint", "endpoint",
[ [
"/v1/embeddings", "/v1/embeddings",
"/v1/completions",
"/v1/invalid/endpoint", "/v1/invalid/endpoint",
"", "",
], ],
@ -499,8 +499,10 @@ class TestReferenceBatchesImpl:
("messages", "body.messages", "invalid_request", "Messages parameter is required"), ("messages", "body.messages", "invalid_request", "Messages parameter is required"),
], ],
) )
async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message): async def test_validate_input_missing_parameters_chat_completions(
"""Test _validate_input when file contains request with missing required parameters.""" self, provider, param_name, param_path, error_code, error_message
):
"""Test _validate_input when file contains request with missing required parameters for chat completions."""
provider.files_api.openai_retrieve_file = AsyncMock() provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock() mock_response = MagicMock()
@ -541,6 +543,61 @@ class TestReferenceBatchesImpl:
assert errors[0].message == error_message assert errors[0].message == error_message
assert errors[0].param == param_path assert errors[0].param == param_path
@pytest.mark.parametrize(
"param_name,param_path,error_code,error_message",
[
("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
("method", "method", "missing_required_parameter", "Missing required parameter: method"),
("url", "url", "missing_required_parameter", "Missing required parameter: url"),
("body", "body", "missing_required_parameter", "Missing required parameter: body"),
("model", "body.model", "invalid_request", "Model parameter is required"),
("prompt", "body.prompt", "invalid_request", "Prompt parameter is required"),
],
)
async def test_validate_input_missing_parameters_completions(
self, provider, param_name, param_path, error_code, error_message
):
"""Test _validate_input when file contains request with missing required parameters for text completions."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
base_request = {
"custom_id": "req-1",
"method": "POST",
"url": "/v1/completions",
"body": {"model": "test-model", "prompt": "Hello"},
}
# Remove the specific parameter being tested
if "." in param_path:
top_level, nested_param = param_path.split(".", 1)
del base_request[top_level][nested_param]
else:
del base_request[param_name]
mock_response.body = json.dumps(base_request).encode()
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/completions",
input_file_id=f"missing_{param_name}_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == error_code
assert errors[0].line == 1
assert errors[0].message == error_message
assert errors[0].param == param_path
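Both parametrized negative tests rely on the same dotted-path trick to strip a field from the request before validation. Pulled out on its own — remove_by_path is a hypothetical helper name, not part of the reference provider — the idea is:

```python
# Standalone sketch of the dotted-path deletion used in the parametrized tests.
def remove_by_path(request: dict, path: str) -> None:
    """Delete a top-level key ('body') or a one-level nested key ('body.model')."""
    if "." in path:
        top_level, nested = path.split(".", 1)
        del request[top_level][nested]
    else:
        del request[path]


req = {"custom_id": "req-1", "method": "POST", "url": "/v1/completions", "body": {"model": "m", "prompt": "Hello"}}
remove_by_path(req, "body.prompt")
assert "prompt" not in req["body"]
remove_by_path(req, "url")
assert "url" not in req
```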
async def test_validate_input_url_mismatch(self, provider): async def test_validate_input_url_mismatch(self, provider):
"""Test _validate_input when file contains request with URL that doesn't match batch endpoint.""" """Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
provider.files_api.openai_retrieve_file = AsyncMock() provider.files_api.openai_retrieve_file = AsyncMock()