Merge branch 'main' into prompt-api

This commit is contained in:
Francisco Arceo 2025-09-05 15:11:00 -06:00 committed by GitHub
commit 5d610de5db
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
34 changed files with 679 additions and 240 deletions

View file

@ -2,13 +2,6 @@ name: 'Run and Record Tests'
description: 'Run integration tests and handle recording/artifact upload'
inputs:
test-subdirs:
description: 'Comma-separated list of test subdirectories to run'
required: true
test-pattern:
description: 'Regex pattern to pass to pytest -k'
required: false
default: ''
stack-config:
description: 'Stack configuration to use'
required: true
@ -18,10 +11,18 @@ inputs:
inference-mode:
description: 'Inference mode (record or replay)'
required: true
run-vision-tests:
description: 'Whether to run vision tests'
test-suite:
description: 'Test suite to use: base, responses, vision, etc.'
required: false
default: 'false'
default: ''
test-subdirs:
description: 'Comma-separated list of test subdirectories to run; overrides test-suite'
required: false
default: ''
test-pattern:
description: 'Regex pattern to pass to pytest -k'
required: false
default: ''
runs:
using: 'composite'
@ -42,7 +43,7 @@ runs:
--test-subdirs '${{ inputs.test-subdirs }}' \
--test-pattern '${{ inputs.test-pattern }}' \
--inference-mode '${{ inputs.inference-mode }}' \
${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }} \
--test-suite '${{ inputs.test-suite }}' \
| tee pytest-${{ inputs.inference-mode }}.log
@ -57,12 +58,7 @@ runs:
echo "New recordings detected, committing and pushing"
git add tests/integration/recordings/
if [ "${{ inputs.run-vision-tests }}" == "true" ]; then
git commit -m "Recordings update from CI (vision)"
else
git commit -m "Recordings update from CI"
fi
git commit -m "Recordings update from CI (test-suite: ${{ inputs.test-suite }})"
git fetch origin ${{ github.ref_name }}
git rebase origin/${{ github.ref_name }}
echo "Rebased successfully"

View file

@ -1,17 +1,17 @@
name: Setup Ollama
description: Start Ollama
inputs:
run-vision-tests:
description: 'Run vision tests: "true" or "false"'
test-suite:
description: 'Test suite to use: base, responses, vision, etc.'
required: false
default: 'false'
default: ''
runs:
using: "composite"
steps:
- name: Start Ollama
shell: bash
run: |
if [ "${{ inputs.run-vision-tests }}" == "true" ]; then
if [ "${{ inputs.test-suite }}" == "vision" ]; then
image="ollama-with-vision-model"
else
image="ollama-with-models"

View file

@ -12,10 +12,10 @@ inputs:
description: 'Provider to setup (ollama or vllm)'
required: true
default: 'ollama'
run-vision-tests:
description: 'Whether to setup provider for vision tests'
test-suite:
description: 'Test suite to use: base, responses, vision, etc.'
required: false
default: 'false'
default: ''
inference-mode:
description: 'Inference mode (record or replay)'
required: true
@ -33,7 +33,7 @@ runs:
if: ${{ inputs.provider == 'ollama' && inputs.inference-mode == 'record' }}
uses: ./.github/actions/setup-ollama
with:
run-vision-tests: ${{ inputs.run-vision-tests }}
test-suite: ${{ inputs.test-suite }}
- name: Setup vllm
if: ${{ inputs.provider == 'vllm' && inputs.inference-mode == 'record' }}

View file

@ -8,7 +8,7 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl
| Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script |
| Integration Auth Tests | [integration-auth-tests.yml](integration-auth-tests.yml) | Run the integration test suite with Kubernetes authentication |
| SqlStore Integration Tests | [integration-sql-store-tests.yml](integration-sql-store-tests.yml) | Run the integration test suite with SqlStore |
| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suite from tests/integration in replay mode |
| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suites from tests/integration in replay mode |
| Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers |
| Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks |
| Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build |

View file

@ -1,6 +1,6 @@
name: Integration Tests (Replay)
run-name: Run the integration test suite from tests/integration in replay mode
run-name: Run the integration test suites from tests/integration in replay mode
on:
push:
@ -32,14 +32,6 @@ on:
description: 'Test against a specific provider'
type: string
default: 'ollama'
test-subdirs:
description: 'Comma-separated list of test subdirectories to run'
type: string
default: ''
test-pattern:
description: 'Regex pattern to pass to pytest -k'
type: string
default: ''
concurrency:
# Skip concurrency for pushes to main - each commit should be tested independently
@ -50,7 +42,7 @@ jobs:
run-replay-mode-tests:
runs-on: ubuntu-latest
name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }}
name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, {4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.test-suite) }}
strategy:
fail-fast: false
@ -61,7 +53,7 @@ jobs:
# Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
run-vision-tests: [true, false]
test-suite: [base, vision]
steps:
- name: Checkout repository
@ -73,15 +65,13 @@ jobs:
python-version: ${{ matrix.python-version }}
client-version: ${{ matrix.client-version }}
provider: ${{ matrix.provider }}
run-vision-tests: ${{ matrix.run-vision-tests }}
test-suite: ${{ matrix.test-suite }}
inference-mode: 'replay'
- name: Run tests
uses: ./.github/actions/run-and-record-tests
with:
test-subdirs: ${{ inputs.test-subdirs }}
test-pattern: ${{ inputs.test-pattern }}
stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }}
provider: ${{ matrix.provider }}
inference-mode: 'replay'
run-vision-tests: ${{ matrix.run-vision-tests }}
test-suite: ${{ matrix.test-suite }}

View file

@ -10,18 +10,18 @@ run-name: Run the integration test suite from tests/integration
on:
workflow_dispatch:
inputs:
test-subdirs:
description: 'Comma-separated list of test subdirectories to run'
type: string
default: ''
test-provider:
description: 'Test against a specific provider'
type: string
default: 'ollama'
run-vision-tests:
description: 'Whether to run vision tests'
type: boolean
default: false
test-suite:
description: 'Test suite to use: base, responses, vision, etc.'
type: string
default: ''
test-subdirs:
description: 'Comma-separated list of test subdirectories to run; overrides test-suite'
type: string
default: ''
test-pattern:
description: 'Regex pattern to pass to pytest -k'
type: string
@ -38,11 +38,11 @@ jobs:
- name: Echo workflow inputs
run: |
echo "::group::Workflow Inputs"
echo "test-subdirs: ${{ inputs.test-subdirs }}"
echo "test-provider: ${{ inputs.test-provider }}"
echo "run-vision-tests: ${{ inputs.run-vision-tests }}"
echo "test-pattern: ${{ inputs.test-pattern }}"
echo "branch: ${{ github.ref_name }}"
echo "test-provider: ${{ inputs.test-provider }}"
echo "test-suite: ${{ inputs.test-suite }}"
echo "test-subdirs: ${{ inputs.test-subdirs }}"
echo "test-pattern: ${{ inputs.test-pattern }}"
echo "::endgroup::"
- name: Checkout repository
@ -56,15 +56,15 @@ jobs:
python-version: "3.12" # Use single Python version for recording
client-version: "latest"
provider: ${{ inputs.test-provider || 'ollama' }}
run-vision-tests: ${{ inputs.run-vision-tests }}
test-suite: ${{ inputs.test-suite }}
inference-mode: 'record'
- name: Run and record tests
uses: ./.github/actions/run-and-record-tests
with:
test-pattern: ${{ inputs.test-pattern }}
test-subdirs: ${{ inputs.test-subdirs }}
stack-config: 'server:ci-tests' # recording must be done with server since more tests are run
provider: ${{ inputs.test-provider || 'ollama' }}
inference-mode: 'record'
run-vision-tests: ${{ inputs.run-vision-tests }}
test-suite: ${{ inputs.test-suite }}
test-subdirs: ${{ inputs.test-subdirs }}
test-pattern: ${{ inputs.test-pattern }}

View file

@ -52,7 +52,6 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
provider_vector_db_id: str | None = None,
vector_db_name: str | None = None,
) -> VectorDB:
provider_vector_db_id = provider_vector_db_id or vector_db_id
if provider_id is None:
if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0]
@ -69,14 +68,33 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
provider = self.impls_by_provider_id[provider_id]
logger.warning(
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
)
vector_store = await provider.openai_create_vector_store(
name=vector_db_name or vector_db_id,
embedding_model=embedding_model,
embedding_dimension=model.metadata["embedding_dimension"],
provider_id=provider_id,
provider_vector_db_id=provider_vector_db_id,
)
vector_store_id = vector_store.id
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
logger.warning(
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
)
vector_db_data = {
"identifier": vector_db_id,
"identifier": vector_store_id,
"type": ResourceType.vector_db.value,
"provider_id": provider_id,
"provider_resource_id": provider_vector_db_id,
"provider_resource_id": actual_provider_vector_db_id,
"embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"],
"vector_db_name": vector_db_name,
"vector_db_name": vector_store.name,
}
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
await self.register_object(vector_db)

View file

@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches):
# TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions"]:
if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
raise ValueError(
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint",
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
)
if completion_window != "24h":
@ -424,13 +424,21 @@ class ReferenceBatchesImpl(Batches):
)
valid = False
for param, expected_type, type_string in [
("model", str, "a string"),
# messages is specific to /v1/chat/completions
# we could skip validating messages here and let inference fail. however,
# that would be a very expensive way to find out messages is wrong.
("messages", list, "an array"), # TODO: allow messages to be a string?
]:
if batch.endpoint == "/v1/chat/completions":
required_params = [
("model", str, "a string"),
# messages is specific to /v1/chat/completions
# we could skip validating messages here and let inference fail. however,
# that would be a very expensive way to find out messages is wrong.
("messages", list, "an array"), # TODO: allow messages to be a string?
]
else: # /v1/completions
required_params = [
("model", str, "a string"),
("prompt", str, "a string"), # TODO: allow prompt to be a list of strings??
]
for param, expected_type, type_string in required_params:
if param not in body:
errors.append(
BatchError(
@ -591,20 +599,37 @@ class ReferenceBatchesImpl(Batches):
try:
# TODO(SECURITY): review body for security issues
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
chat_response = await self.inference_api.openai_chat_completion(**request.body)
if request.url == "/v1/chat/completions":
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
chat_response = await self.inference_api.openai_chat_completion(**request.body)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id, # TODO: should this be different?
"body": chat_response.model_dump_json(),
},
}
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method"
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id, # TODO: should this be different?
"body": chat_response.model_dump_json(),
},
}
else: # /v1/completions
completion_response = await self.inference_api.openai_completion(**request.body)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(completion_response, "model_dump_json"), (
"Completion response must have model_dump_json method"
)
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id,
"body": completion_response.model_dump_json(),
},
}
except Exception as e:
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
return {

View file

@ -15,7 +15,7 @@ set -euo pipefail
BRANCH=""
TEST_SUBDIRS=""
TEST_PROVIDER="ollama"
RUN_VISION_TESTS=false
TEST_SUITE="base"
TEST_PATTERN=""
# Help function
@ -27,9 +27,9 @@ Trigger the integration test recording workflow remotely. This way you do not ne
OPTIONS:
-b, --branch BRANCH Branch to run the workflow on (defaults to current branch)
-s, --test-subdirs DIRS Comma-separated list of test subdirectories to run (REQUIRED)
-p, --test-provider PROVIDER Test provider to use: vllm or ollama (default: ollama)
-v, --run-vision-tests Include vision tests in the recording
-t, --test-suite SUITE Test suite to use: base, responses, vision, etc. (default: base)
-s, --test-subdirs DIRS Comma-separated list of test subdirectories to run (overrides suite)
-k, --test-pattern PATTERN Regex pattern to pass to pytest -k
-h, --help Show this help message
@ -38,7 +38,7 @@ EXAMPLES:
$0 --test-subdirs "agents"
# Record tests for specific branch with vision tests
$0 -b my-feature-branch --test-subdirs "inference" --run-vision-tests
$0 -b my-feature-branch --test-suite vision
# Record multiple test subdirectories with specific provider
$0 --test-subdirs "agents,inference" --test-provider vllm
@ -71,9 +71,9 @@ while [[ $# -gt 0 ]]; do
TEST_PROVIDER="$2"
shift 2
;;
-v|--run-vision-tests)
RUN_VISION_TESTS=true
shift
-t|--test-suite)
TEST_SUITE="$2"
shift 2
;;
-k|--test-pattern)
TEST_PATTERN="$2"
@ -92,11 +92,11 @@ while [[ $# -gt 0 ]]; do
done
# Validate required parameters
if [[ -z "$TEST_SUBDIRS" ]]; then
echo "Error: --test-subdirs is required"
echo "Please specify which test subdirectories to run, e.g.:"
if [[ -z "$TEST_SUBDIRS" && -z "$TEST_SUITE" ]]; then
echo "Error: --test-subdirs or --test-suite is required"
echo "Please specify which test subdirectories to run or test suite to use, e.g.:"
echo " $0 --test-subdirs \"agents,inference\""
echo " $0 --test-subdirs \"inference\" --run-vision-tests"
echo " $0 --test-suite vision"
echo ""
exit 1
fi
@ -239,17 +239,19 @@ echo "Triggering integration test recording workflow..."
echo "Branch: $BRANCH"
echo "Test provider: $TEST_PROVIDER"
echo "Test subdirs: $TEST_SUBDIRS"
echo "Run vision tests: $RUN_VISION_TESTS"
echo "Test suite: $TEST_SUITE"
echo "Test pattern: ${TEST_PATTERN:-"(none)"}"
echo ""
# Prepare inputs for gh workflow run
INPUTS="-f test-subdirs='$TEST_SUBDIRS'"
if [[ -n "$TEST_SUBDIRS" ]]; then
INPUTS="-f test-subdirs='$TEST_SUBDIRS'"
fi
if [[ -n "$TEST_PROVIDER" ]]; then
INPUTS="$INPUTS -f test-provider='$TEST_PROVIDER'"
fi
if [[ "$RUN_VISION_TESTS" == "true" ]]; then
INPUTS="$INPUTS -f run-vision-tests=true"
if [[ -n "$TEST_SUITE" ]]; then
INPUTS="$INPUTS -f test-suite='$TEST_SUITE'"
fi
if [[ -n "$TEST_PATTERN" ]]; then
INPUTS="$INPUTS -f test-pattern='$TEST_PATTERN'"

View file

@ -16,7 +16,7 @@ STACK_CONFIG=""
PROVIDER=""
TEST_SUBDIRS=""
TEST_PATTERN=""
RUN_VISION_TESTS="false"
TEST_SUITE="base"
INFERENCE_MODE="replay"
EXTRA_PARAMS=""
@ -28,12 +28,16 @@ Usage: $0 [OPTIONS]
Options:
--stack-config STRING Stack configuration to use (required)
--provider STRING Provider to use (ollama, vllm, etc.) (required)
--test-subdirs STRING Comma-separated list of test subdirectories to run (default: 'inference')
--run-vision-tests Run vision tests instead of regular tests
--test-suite STRING Comma-separated list of test suites to run (default: 'base')
--inference-mode STRING Inference mode: record or replay (default: replay)
--test-subdirs STRING Comma-separated list of test subdirectories to run (overrides suite)
--test-pattern STRING Regex pattern to pass to pytest -k
--help Show this help message
Suites are defined in tests/integration/suites.py. They are used to narrow the collection of tests and provide default model options.
You can also specify subdirectories (of tests/integration) to select tests from, which will override the suite.
Examples:
# Basic inference tests with ollama
$0 --stack-config server:ci-tests --provider ollama
@ -42,7 +46,7 @@ Examples:
$0 --stack-config server:ci-tests --provider vllm --test-subdirs 'inference,agents'
# Vision tests with ollama
$0 --stack-config server:ci-tests --provider ollama --run-vision-tests
$0 --stack-config server:ci-tests --provider ollama --test-suite vision
# Record mode for updating test recordings
$0 --stack-config server:ci-tests --provider ollama --inference-mode record
@ -64,9 +68,9 @@ while [[ $# -gt 0 ]]; do
TEST_SUBDIRS="$2"
shift 2
;;
--run-vision-tests)
RUN_VISION_TESTS="true"
shift
--test-suite)
TEST_SUITE="$2"
shift 2
;;
--inference-mode)
INFERENCE_MODE="$2"
@ -92,22 +96,25 @@ done
# Validate required parameters
if [[ -z "$STACK_CONFIG" ]]; then
echo "Error: --stack-config is required"
usage
exit 1
fi
if [[ -z "$PROVIDER" ]]; then
echo "Error: --provider is required"
usage
exit 1
fi
if [[ -z "$TEST_SUITE" && -z "$TEST_SUBDIRS" ]]; then
echo "Error: --test-suite or --test-subdirs is required"
exit 1
fi
echo "=== Llama Stack Integration Test Runner ==="
echo "Stack Config: $STACK_CONFIG"
echo "Provider: $PROVIDER"
echo "Test Subdirs: $TEST_SUBDIRS"
echo "Vision Tests: $RUN_VISION_TESTS"
echo "Inference Mode: $INFERENCE_MODE"
echo "Test Suite: $TEST_SUITE"
echo "Test Subdirs: $TEST_SUBDIRS"
echo "Test Pattern: $TEST_PATTERN"
echo ""
@ -194,84 +201,46 @@ if [[ -n "$TEST_PATTERN" ]]; then
PYTEST_PATTERN="${PYTEST_PATTERN} and $TEST_PATTERN"
fi
# Run vision tests if specified
if [[ "$RUN_VISION_TESTS" == "true" ]]; then
echo "Running vision tests..."
set +e
pytest -s -v tests/integration/inference/test_vision_inference.py \
--stack-config="$STACK_CONFIG" \
-k "$PYTEST_PATTERN" \
--vision-model=ollama/llama3.2-vision:11b \
--embedding-model=sentence-transformers/all-MiniLM-L6-v2 \
--color=yes $EXTRA_PARAMS \
--capture=tee-sys
exit_code=$?
set -e
if [ $exit_code -eq 0 ]; then
echo "✅ Vision tests completed successfully"
elif [ $exit_code -eq 5 ]; then
echo "⚠️ No vision tests collected (pattern matched no tests)"
else
echo "❌ Vision tests failed"
exit 1
fi
exit 0
fi
# Run regular tests
if [[ -z "$TEST_SUBDIRS" ]]; then
TEST_SUBDIRS=$(find tests/integration -maxdepth 1 -mindepth 1 -type d |
sed 's|tests/integration/||' |
grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" |
sort)
fi
echo "Test subdirs to run: $TEST_SUBDIRS"
# Collect all test files for the specified test types
TEST_FILES=""
for test_subdir in $(echo "$TEST_SUBDIRS" | tr ',' '\n'); do
# Skip certain test types for vllm provider
if [[ "$PROVIDER" == "vllm" ]]; then
if [[ "$test_subdir" == "safety" ]] || [[ "$test_subdir" == "post_training" ]] || [[ "$test_subdir" == "tool_runtime" ]]; then
echo "Skipping $test_subdir for vllm provider"
continue
if [[ -n "$TEST_SUBDIRS" ]]; then
# Collect all test files for the specified test types
TEST_FILES=""
for test_subdir in $(echo "$TEST_SUBDIRS" | tr ',' '\n'); do
if [[ -d "tests/integration/$test_subdir" ]]; then
# Find all Python test files in this directory
test_files=$(find tests/integration/$test_subdir -name "test_*.py" -o -name "*_test.py")
if [[ -n "$test_files" ]]; then
TEST_FILES="$TEST_FILES $test_files"
echo "Added test files from $test_subdir: $(echo $test_files | wc -w) files"
fi
else
echo "Warning: Directory tests/integration/$test_subdir does not exist"
fi
done
if [[ -z "$TEST_FILES" ]]; then
echo "No test files found for the specified test types"
exit 1
fi
if [[ "$STACK_CONFIG" != *"server:"* ]] && [[ "$test_subdir" == "batches" ]]; then
echo "Skipping $test_subdir for library client until types are supported"
continue
fi
echo ""
echo "=== Running all collected tests in a single pytest command ==="
echo "Total test files: $(echo $TEST_FILES | wc -w)"
if [[ -d "tests/integration/$test_subdir" ]]; then
# Find all Python test files in this directory
test_files=$(find tests/integration/$test_subdir -name "test_*.py" -o -name "*_test.py")
if [[ -n "$test_files" ]]; then
TEST_FILES="$TEST_FILES $test_files"
echo "Added test files from $test_subdir: $(echo $test_files | wc -w) files"
fi
else
echo "Warning: Directory tests/integration/$test_subdir does not exist"
fi
done
if [[ -z "$TEST_FILES" ]]; then
echo "No test files found for the specified test types"
exit 1
PYTEST_TARGET="$TEST_FILES"
EXTRA_PARAMS="$EXTRA_PARAMS --text-model=$TEXT_MODEL --embedding-model=sentence-transformers/all-MiniLM-L6-v2"
else
PYTEST_TARGET="tests/integration/"
EXTRA_PARAMS="$EXTRA_PARAMS --suite=$TEST_SUITE"
fi
echo ""
echo "=== Running all collected tests in a single pytest command ==="
echo "Total test files: $(echo $TEST_FILES | wc -w)"
set +e
pytest -s -v $TEST_FILES \
pytest -s -v $PYTEST_TARGET \
--stack-config="$STACK_CONFIG" \
-k "$PYTEST_PATTERN" \
--text-model="$TEXT_MODEL" \
--embedding-model=sentence-transformers/all-MiniLM-L6-v2 \
--color=yes $EXTRA_PARAMS \
$EXTRA_PARAMS \
--color=yes \
--capture=tee-sys
exit_code=$?
set -e
@ -294,7 +263,13 @@ df -h
# stop server
if [[ "$STACK_CONFIG" == *"server:"* ]]; then
echo "Stopping Llama Stack Server..."
kill $(lsof -i :8321 | awk 'NR>1 {print $2}')
pids=$(lsof -i :8321 | awk 'NR>1 {print $2}')
if [[ -n "$pids" ]]; then
echo "Killing Llama Stack Server processes: $pids"
kill -9 $pids
else
echo "No Llama Stack Server processes found ?!"
fi
echo "Llama Stack Server stopped"
fi

View file

@ -77,7 +77,7 @@ You must be careful when re-recording. CI workflows assume a specific setup for
./scripts/github/schedule-record-workflow.sh --test-subdirs "agents,inference"
# Record with vision tests enabled
./scripts/github/schedule-record-workflow.sh --test-subdirs "inference" --run-vision-tests
./scripts/github/schedule-record-workflow.sh --test-suite vision
# Record with specific provider
./scripts/github/schedule-record-workflow.sh --test-subdirs "agents" --test-provider vllm

View file

@ -42,6 +42,27 @@ Model parameters can be influenced by the following options:
Each of these are comma-separated lists and can be used to generate multiple parameter combinations. Note that tests will be skipped
if no model is specified.
### Suites (fast selection + sane defaults)
- `--suite`: comma-separated list of named suites that both narrow which tests are collected and prefill common model options (unless you pass them explicitly).
- Available suites:
- `responses`: collects tests under `tests/integration/responses`; this is a separate suite because it needs a strong tool-calling model.
- `vision`: collects only `tests/integration/inference/test_vision_inference.py`; defaults `--vision-model=ollama/llama3.2-vision:11b`, `--embedding-model=sentence-transformers/all-MiniLM-L6-v2`.
- Explicit flags always win. For example, `--suite=responses --text-model=<X>` overrides the suites text model.
Examples:
```bash
# Fast responses run with defaults
pytest -s -v tests/integration --stack-config=server:starter --suite=responses
# Fast single-file vision run with defaults
pytest -s -v tests/integration --stack-config=server:starter --suite=vision
# Combine suites and override a default
pytest -s -v tests/integration --stack-config=server:starter --suite=responses,vision --embedding-model=text-embedding-3-small
```
## Examples
### Testing against a Server

View file

@ -268,3 +268,58 @@ class TestBatchesIntegration:
deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully"
def test_batch_e2e_completions(self, openai_client, batch_helper, text_model_id):
"""Run an end-to-end batch with a single successful text completion request."""
request_body = {"model": text_model_id, "prompt": "Say completions", "max_tokens": 20}
batch_requests = [
{
"custom_id": "success-1",
"method": "POST",
"url": "/v1/completions",
"body": request_body,
}
]
with batch_helper.create_file(batch_requests) as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/completions",
completion_window="24h",
metadata={"test": "e2e_completions_success"},
)
final_batch = batch_helper.wait_for(
batch.id,
max_wait_time=3 * 60,
expected_statuses={"completed"},
timeout_action="skip",
)
assert final_batch.status == "completed"
assert final_batch.request_counts is not None
assert final_batch.request_counts.total == 1
assert final_batch.request_counts.completed == 1
assert final_batch.output_file_id is not None
output_content = openai_client.files.content(final_batch.output_file_id)
if isinstance(output_content, str):
output_text = output_content
else:
output_text = output_content.content.decode("utf-8")
output_lines = output_text.strip().split("\n")
assert len(output_lines) == 1
result = json.loads(output_lines[0])
assert result["custom_id"] == "success-1"
assert "response" in result
assert result["response"]["status_code"] == 200
deleted_output_file = openai_client.files.delete(final_batch.output_file_id)
assert deleted_output_file.deleted
if final_batch.error_file_id is not None:
deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
assert deleted_error_file.deleted

View file

@ -6,15 +6,17 @@
import inspect
import itertools
import os
import platform
import textwrap
import time
from pathlib import Path
import pytest
from dotenv import load_dotenv
from llama_stack.log import get_logger
from .suites import SUITE_DEFINITIONS
logger = get_logger(__name__, category="tests")
@ -61,9 +63,22 @@ def pytest_configure(config):
key, value = env_var.split("=", 1)
os.environ[key] = value
if platform.system() == "Darwin": # Darwin is the system name for macOS
os.environ["DISABLE_CODE_SANDBOX"] = "1"
logger.info("Setting DISABLE_CODE_SANDBOX=1 for macOS")
suites_raw = config.getoption("--suite")
suites: list[str] = []
if suites_raw:
suites = [p.strip() for p in str(suites_raw).split(",") if p.strip()]
unknown = [p for p in suites if p not in SUITE_DEFINITIONS]
if unknown:
raise pytest.UsageError(
f"Unknown suite(s): {', '.join(unknown)}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}"
)
for suite in suites:
suite_def = SUITE_DEFINITIONS.get(suite, {})
defaults: dict = suite_def.get("defaults", {})
for dest, value in defaults.items():
current = getattr(config.option, dest, None)
if not current:
setattr(config.option, dest, value)
def pytest_addoption(parser):
@ -105,16 +120,21 @@ def pytest_addoption(parser):
default=384,
help="Output dimensionality of the embedding model to use for testing. Default: 384",
)
parser.addoption(
"--record-responses",
action="store_true",
help="Record new API responses instead of using cached ones.",
)
parser.addoption(
"--report",
help="Path where the test report should be written, e.g. --report=/path/to/report.md",
)
available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys()))
suite_help = (
"Comma-separated integration test suites to narrow collection and prefill defaults. "
"Available: "
f"{available_suites}. "
"Explicit CLI flags (e.g., --text-model) override suite defaults. "
"Examples: --suite=responses or --suite=responses,vision."
)
parser.addoption("--suite", help=suite_help)
MODEL_SHORT_IDS = {
"meta-llama/Llama-3.2-3B-Instruct": "3B",
@ -197,3 +217,40 @@ def pytest_generate_tests(metafunc):
pytest_plugins = ["tests.integration.fixtures.common"]
def pytest_ignore_collect(path: str, config: pytest.Config) -> bool:
"""Skip collecting paths outside the selected suite roots for speed."""
suites_raw = config.getoption("--suite")
if not suites_raw:
return False
names = [p.strip() for p in str(suites_raw).split(",") if p.strip()]
roots: list[str] = []
for name in names:
suite_def = SUITE_DEFINITIONS.get(name)
if suite_def:
roots.extend(suite_def.get("roots", []))
if not roots:
return False
p = Path(str(path)).resolve()
# Only constrain within tests/integration to avoid ignoring unrelated tests
integration_root = (Path(str(config.rootpath)) / "tests" / "integration").resolve()
if not p.is_relative_to(integration_root):
return False
for r in roots:
rp = (Path(str(config.rootpath)) / r).resolve()
if rp.is_file():
# Allow the exact file and any ancestor directories so pytest can walk into it.
if p == rp:
return False
if p.is_dir() and rp.is_relative_to(p):
return False
else:
# Allow anything inside an allowed directory
if p.is_relative_to(rp):
return False
return True

View file

@ -58,6 +58,15 @@ def skip_if_model_doesnt_support_suffix(client_with_models, model_id):
pytest.skip(f"Provider {provider.provider_type} doesn't support suffix.")
def skip_if_doesnt_support_n(client_with_models, model_id):
provider = provider_from_model(client_with_models, model_id)
if provider.provider_type in (
"remote::sambanova",
"remote::ollama",
):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")
def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
provider = provider_from_model(client_with_models, model_id)
if provider.provider_type in (
@ -262,10 +271,7 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
)
def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
provider = provider_from_model(client_with_models, text_model_id)
if provider.provider_type == "remote::ollama":
pytest.skip(f"Model {text_model_id} hosted by {provider.provider_type} doesn't support n > 1.")
skip_if_doesnt_support_n(client_with_models, text_model_id)
tc = TestCase(test_case)
question = tc["question"]

View file

@ -0,0 +1,42 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"prompt": "Say completions",
"max_tokens": 20
},
"endpoint": "/v1/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": {
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-271",
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": "You want me to respond with a completion, but you didn't specify what I should complete. Could"
}
],
"created": 1756846620,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": {
"completion_tokens": 20,
"prompt_tokens": 28,
"total_tokens": 48,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
"is_streaming": false
}
}

View file

Before

Width:  |  Height:  |  Size: 108 KiB

After

Width:  |  Height:  |  Size: 108 KiB

Before After
Before After

View file

Before

Width:  |  Height:  |  Size: 148 KiB

After

Width:  |  Height:  |  Size: 148 KiB

Before After
Before After

View file

Before

Width:  |  Height:  |  Size: 139 KiB

After

Width:  |  Height:  |  Size: 139 KiB

Before After
Before After

View file

@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Central definition of integration test suites. You can use these suites by passing --suite=name to pytest.
# For example:
#
# ```bash
# pytest tests/integration/ --suite=vision
# ```
#
# Each suite can:
# - restrict collection to specific roots (dirs or files)
# - provide default CLI option values (e.g. text_model, embedding_model, etc.)
from pathlib import Path
this_dir = Path(__file__).parent
default_roots = [
str(p)
for p in this_dir.glob("*")
if p.is_dir()
and p.name not in ("__pycache__", "fixtures", "test_cases", "recordings", "responses", "post_training")
]
SUITE_DEFINITIONS: dict[str, dict] = {
"base": {
"description": "Base suite that includes most tests but runs them with a text Ollama model",
"roots": default_roots,
"defaults": {
"text_model": "ollama/llama3.2:3b-instruct-fp16",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
},
},
"responses": {
"description": "Suite that includes only the OpenAI Responses tests; needs a strong tool-calling model",
"roots": ["tests/integration/responses"],
"defaults": {
"text_model": "openai/gpt-4o",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
},
},
"vision": {
"description": "Suite that includes only the vision tests",
"roots": ["tests/integration/inference/test_vision_inference.py"],
"defaults": {
"vision_model": "ollama/llama3.2-vision:11b",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
},
},
}

View file

@ -47,34 +47,45 @@ def client_with_empty_registry(client_with_models):
def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension):
# Register a memory bank first
vector_db_id = "test_vector_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
vector_db_name = "test_vector_db"
register_response = client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_name,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
actual_vector_db_id = register_response.identifier
# Retrieve the memory bank and validate its properties
response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=vector_db_id)
response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=actual_vector_db_id)
assert response is not None
assert response.identifier == vector_db_id
assert response.identifier == actual_vector_db_id
assert response.embedding_model == embedding_model_id
assert response.provider_resource_id == vector_db_id
assert response.identifier.startswith("vs_")
def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension):
vector_db_id = "test_vector_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
vector_db_name = "test_vector_db"
response = client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_name,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
assert vector_dbs_after_register == [vector_db_id]
actual_vector_db_id = response.identifier
assert actual_vector_db_id.startswith("vs_")
assert actual_vector_db_id != vector_db_name
client_with_empty_registry.vector_dbs.unregister(vector_db_id=vector_db_id)
vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
assert vector_dbs_after_register == [actual_vector_db_id]
vector_stores = client_with_empty_registry.vector_stores.list()
assert len(vector_stores.data) == 1
vector_store = vector_stores.data[0]
assert vector_store.id == actual_vector_db_id
assert vector_store.name == vector_db_name
client_with_empty_registry.vector_dbs.unregister(vector_db_id=actual_vector_db_id)
vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
assert len(vector_dbs) == 0
@ -91,20 +102,22 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe
],
)
def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case):
vector_db_id = "test_vector_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
vector_db_name = "test_vector_db"
register_response = client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_name,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
actual_vector_db_id = register_response.identifier
client_with_empty_registry.vector_io.insert(
vector_db_id=vector_db_id,
vector_db_id=actual_vector_db_id,
chunks=sample_chunks,
)
response = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id,
vector_db_id=actual_vector_db_id,
query="What is the capital of France?",
)
assert response is not None
@ -113,7 +126,7 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding
query, expected_doc_id = test_case
response = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id,
vector_db_id=actual_vector_db_id,
query=query,
)
assert response is not None
@ -128,13 +141,15 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
"remote::qdrant": {"score_threshold": -1.0},
"inline::qdrant": {"score_threshold": -1.0},
}
vector_db_id = "test_precomputed_embeddings_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
vector_db_name = "test_precomputed_embeddings_db"
register_response = client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_name,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
actual_vector_db_id = register_response.identifier
chunks_with_embeddings = [
Chunk(
content="This is a test chunk with precomputed embedding.",
@ -144,13 +159,13 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
]
client_with_empty_registry.vector_io.insert(
vector_db_id=vector_db_id,
vector_db_id=actual_vector_db_id,
chunks=chunks_with_embeddings,
)
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
response = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id,
vector_db_id=actual_vector_db_id,
query="precomputed embedding test",
params=vector_io_provider_params_dict.get(provider, None),
)
@ -173,13 +188,15 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
"remote::qdrant": {"score_threshold": 0.0},
"inline::qdrant": {"score_threshold": 0.0},
}
vector_db_id = "test_precomputed_embeddings_db"
client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id,
vector_db_name = "test_precomputed_embeddings_db"
register_response = client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_name,
embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension,
)
actual_vector_db_id = register_response.identifier
chunks_with_embeddings = [
Chunk(
content="duplicate",
@ -189,13 +206,13 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
]
client_with_empty_registry.vector_io.insert(
vector_db_id=vector_db_id,
vector_db_id=actual_vector_db_id,
chunks=chunks_with_embeddings,
)
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
response = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id,
vector_db_id=actual_vector_db_id,
query="duplicate",
params=vector_io_provider_params_dict.get(provider, None),
)

View file

@ -146,6 +146,20 @@ class VectorDBImpl(Impl):
async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id
async def openai_create_vector_store(self, **kwargs):
import time
import uuid
from llama_stack.apis.vector_io.vector_io import VectorStoreFileCounts, VectorStoreObject
vector_store_id = kwargs.get("provider_vector_db_id") or f"vs_{uuid.uuid4()}"
return VectorStoreObject(
id=vector_store_id,
name=kwargs.get("name", vector_store_id),
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
@ -247,17 +261,21 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
)
# Register multiple vector databases and verify listing
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
vdb1 = await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
vdb2 = await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2
vector_db_ids = {v.identifier for v in vector_dbs.data}
assert "test-vectordb" in vector_db_ids
assert "test-vectordb-2" in vector_db_ids
assert vdb1.identifier in vector_db_ids
assert vdb2.identifier in vector_db_ids
await table.unregister_vector_db(vector_db_id="test-vectordb")
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
# Verify they have UUID-based identifiers
assert vdb1.identifier.startswith("vs_")
assert vdb2.identifier.startswith("vs_")
await table.unregister_vector_db(vector_db_id=vdb1.identifier)
await table.unregister_vector_db(vector_db_id=vdb2.identifier)
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0

View file

@ -7,6 +7,7 @@
# Unit tests for the routing tables vector_dbs
import time
import uuid
from unittest.mock import AsyncMock
import pytest
@ -34,6 +35,7 @@ from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceI
class VectorDBImpl(Impl):
def __init__(self):
super().__init__(Api.vector_io)
self.vector_stores = {}
async def register_vector_db(self, vector_db: VectorDB):
return vector_db
@ -114,8 +116,35 @@ class VectorDBImpl(Impl):
async def openai_delete_vector_store_file(self, vector_store_id, file_id):
return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
async def openai_create_vector_store(
self,
name=None,
embedding_model=None,
embedding_dimension=None,
provider_id=None,
provider_vector_db_id=None,
**kwargs,
):
vector_store_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
vector_store = VectorStoreObject(
id=vector_store_id,
name=name or vector_store_id,
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
self.vector_stores[vector_store_id] = vector_store
return vector_store
async def openai_list_vector_stores(self, **kwargs):
from llama_stack.apis.vector_io.vector_io import VectorStoreListResponse
return VectorStoreListResponse(
data=list(self.vector_stores.values()), has_more=False, first_id=None, last_id=None
)
async def test_vectordbs_routing_table(cached_disk_dist_registry):
n = 10
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize()
@ -129,22 +158,98 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
)
# Register multiple vector databases and verify listing
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model")
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model")
vdb_dict = {}
for i in range(n):
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2
assert len(vector_dbs.data) == len(vdb_dict)
vector_db_ids = {v.identifier for v in vector_dbs.data}
assert "test-vectordb" in vector_db_ids
assert "test-vectordb-2" in vector_db_ids
await table.unregister_vector_db(vector_db_id="test-vectordb")
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
for k in vdb_dict:
assert vdb_dict[k].identifier in vector_db_ids
for k in vdb_dict:
await table.unregister_vector_db(vector_db_id=vdb_dict[k].identifier)
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0
async def test_vector_db_and_vector_store_id_mapping(cached_disk_dist_registry):
n = 10
impl = VectorDBImpl()
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
vdb_dict = {}
for i in range(n):
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs()
vector_db_ids = {v.identifier for v in vector_dbs.data}
vector_stores = await impl.openai_list_vector_stores()
vector_store_ids = {v.id for v in vector_stores.data}
assert vector_db_ids == vector_store_ids, (
f"Vector DB IDs {vector_db_ids} don't match vector store IDs {vector_store_ids}"
)
for vector_store in vector_stores.data:
vector_db = await table.get_vector_db(vector_store.id)
assert vector_store.name == vector_db.vector_db_name, (
f"Vector store name {vector_store.name} doesn't match vector store ID {vector_store.id}"
)
for vector_db_id in vector_db_ids:
await table.unregister_vector_db(vector_db_id)
assert len((await table.list_vector_dbs()).data) == 0
async def test_vector_db_id_becomes_vector_store_name(cached_disk_dist_registry):
impl = VectorDBImpl()
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
user_provided_id = "my-custom-vector-db"
await table.register_vector_db(vector_db_id=user_provided_id, embedding_model="test-model")
vector_stores = await impl.openai_list_vector_stores()
assert len(vector_stores.data) == 1
vector_store = vector_stores.data[0]
assert vector_store.name == user_provided_id
assert vector_store.id.startswith("vs_")
assert vector_store.id != user_provided_id
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 1
assert vector_dbs.data[0].identifier == vector_store.id
await table.unregister_vector_db(vector_store.id)
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
impl = VectorDBImpl()
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
@ -164,7 +269,8 @@ async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registr
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
with request_provider_data_context({}, authorized_user):
_ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
registered_vdb = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
authorized_table = registered_vdb.identifier # Use the actual generated ID
# Authorized reader
with request_provider_data_context({}, authorized_user):
@ -227,7 +333,8 @@ async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_regis
)
with request_provider_data_context({}, admin_user):
await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
registered_vdb = await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
vector_db_id = registered_vdb.identifier # Use the actual generated ID
read_methods = [
(table.openai_retrieve_vector_store, (vector_db_id,), {}),

View file

@ -46,7 +46,8 @@ The tests are categorized and outlined below, keep this updated:
* test_validate_input_url_mismatch (negative)
* test_validate_input_multiple_errors_per_request (negative)
* test_validate_input_invalid_request_format (negative)
* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation)
* test_validate_input_missing_parameters_chat_completions (parametrized negative - custom_id, method, url, body, model, messages missing validation for chat/completions)
* test_validate_input_missing_parameters_completions (parametrized negative - custom_id, method, url, body, model, prompt missing validation for completions)
* test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
The tests use temporary SQLite databases for isolation and mock external
@ -213,7 +214,6 @@ class TestReferenceBatchesImpl:
"endpoint",
[
"/v1/embeddings",
"/v1/completions",
"/v1/invalid/endpoint",
"",
],
@ -499,8 +499,10 @@ class TestReferenceBatchesImpl:
("messages", "body.messages", "invalid_request", "Messages parameter is required"),
],
)
async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message):
"""Test _validate_input when file contains request with missing required parameters."""
async def test_validate_input_missing_parameters_chat_completions(
self, provider, param_name, param_path, error_code, error_message
):
"""Test _validate_input when file contains request with missing required parameters for chat completions."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
@ -541,6 +543,61 @@ class TestReferenceBatchesImpl:
assert errors[0].message == error_message
assert errors[0].param == param_path
@pytest.mark.parametrize(
"param_name,param_path,error_code,error_message",
[
("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
("method", "method", "missing_required_parameter", "Missing required parameter: method"),
("url", "url", "missing_required_parameter", "Missing required parameter: url"),
("body", "body", "missing_required_parameter", "Missing required parameter: body"),
("model", "body.model", "invalid_request", "Model parameter is required"),
("prompt", "body.prompt", "invalid_request", "Prompt parameter is required"),
],
)
async def test_validate_input_missing_parameters_completions(
self, provider, param_name, param_path, error_code, error_message
):
"""Test _validate_input when file contains request with missing required parameters for text completions."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
base_request = {
"custom_id": "req-1",
"method": "POST",
"url": "/v1/completions",
"body": {"model": "test-model", "prompt": "Hello"},
}
# Remove the specific parameter being tested
if "." in param_path:
top_level, nested_param = param_path.split(".", 1)
del base_request[top_level][nested_param]
else:
del base_request[param_name]
mock_response.body = json.dumps(base_request).encode()
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/completions",
input_file_id=f"missing_{param_name}_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == error_code
assert errors[0].line == 1
assert errors[0].message == error_message
assert errors[0].param == param_path
async def test_validate_input_url_mismatch(self, provider):
"""Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
provider.files_api.openai_retrieve_file = AsyncMock()