Merge branch 'main' into fix/embedding-model-type

This commit is contained in:
raghotham 2025-09-06 12:29:54 -07:00 committed by GitHub
commit 309f06829c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
59 changed files with 1005 additions and 339 deletions

View file

@ -2,13 +2,6 @@ name: 'Run and Record Tests'
description: 'Run integration tests and handle recording/artifact upload' description: 'Run integration tests and handle recording/artifact upload'
inputs: inputs:
test-subdirs:
description: 'Comma-separated list of test subdirectories to run'
required: true
test-pattern:
description: 'Regex pattern to pass to pytest -k'
required: false
default: ''
stack-config: stack-config:
description: 'Stack configuration to use' description: 'Stack configuration to use'
required: true required: true
@ -18,10 +11,18 @@ inputs:
inference-mode: inference-mode:
description: 'Inference mode (record or replay)' description: 'Inference mode (record or replay)'
required: true required: true
run-vision-tests: test-suite:
description: 'Whether to run vision tests' description: 'Test suite to use: base, responses, vision, etc.'
required: false required: false
default: 'false' default: ''
test-subdirs:
description: 'Comma-separated list of test subdirectories to run; overrides test-suite'
required: false
default: ''
test-pattern:
description: 'Regex pattern to pass to pytest -k'
required: false
default: ''
runs: runs:
using: 'composite' using: 'composite'
@ -42,7 +43,7 @@ runs:
--test-subdirs '${{ inputs.test-subdirs }}' \ --test-subdirs '${{ inputs.test-subdirs }}' \
--test-pattern '${{ inputs.test-pattern }}' \ --test-pattern '${{ inputs.test-pattern }}' \
--inference-mode '${{ inputs.inference-mode }}' \ --inference-mode '${{ inputs.inference-mode }}' \
${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }} \ --test-suite '${{ inputs.test-suite }}' \
| tee pytest-${{ inputs.inference-mode }}.log | tee pytest-${{ inputs.inference-mode }}.log
@ -57,12 +58,7 @@ runs:
echo "New recordings detected, committing and pushing" echo "New recordings detected, committing and pushing"
git add tests/integration/recordings/ git add tests/integration/recordings/
if [ "${{ inputs.run-vision-tests }}" == "true" ]; then git commit -m "Recordings update from CI (test-suite: ${{ inputs.test-suite }})"
git commit -m "Recordings update from CI (vision)"
else
git commit -m "Recordings update from CI"
fi
git fetch origin ${{ github.ref_name }} git fetch origin ${{ github.ref_name }}
git rebase origin/${{ github.ref_name }} git rebase origin/${{ github.ref_name }}
echo "Rebased successfully" echo "Rebased successfully"

View file

@ -1,17 +1,17 @@
name: Setup Ollama name: Setup Ollama
description: Start Ollama description: Start Ollama
inputs: inputs:
run-vision-tests: test-suite:
description: 'Run vision tests: "true" or "false"' description: 'Test suite to use: base, responses, vision, etc.'
required: false required: false
default: 'false' default: ''
runs: runs:
using: "composite" using: "composite"
steps: steps:
- name: Start Ollama - name: Start Ollama
shell: bash shell: bash
run: | run: |
if [ "${{ inputs.run-vision-tests }}" == "true" ]; then if [ "${{ inputs.test-suite }}" == "vision" ]; then
image="ollama-with-vision-model" image="ollama-with-vision-model"
else else
image="ollama-with-models" image="ollama-with-models"

View file

@ -12,10 +12,10 @@ inputs:
description: 'Provider to setup (ollama or vllm)' description: 'Provider to setup (ollama or vllm)'
required: true required: true
default: 'ollama' default: 'ollama'
run-vision-tests: test-suite:
description: 'Whether to setup provider for vision tests' description: 'Test suite to use: base, responses, vision, etc.'
required: false required: false
default: 'false' default: ''
inference-mode: inference-mode:
description: 'Inference mode (record or replay)' description: 'Inference mode (record or replay)'
required: true required: true
@ -33,7 +33,7 @@ runs:
if: ${{ inputs.provider == 'ollama' && inputs.inference-mode == 'record' }} if: ${{ inputs.provider == 'ollama' && inputs.inference-mode == 'record' }}
uses: ./.github/actions/setup-ollama uses: ./.github/actions/setup-ollama
with: with:
run-vision-tests: ${{ inputs.run-vision-tests }} test-suite: ${{ inputs.test-suite }}
- name: Setup vllm - name: Setup vllm
if: ${{ inputs.provider == 'vllm' && inputs.inference-mode == 'record' }} if: ${{ inputs.provider == 'vllm' && inputs.inference-mode == 'record' }}

View file

@ -8,7 +8,7 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl
| Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script | | Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script |
| Integration Auth Tests | [integration-auth-tests.yml](integration-auth-tests.yml) | Run the integration test suite with Kubernetes authentication | | Integration Auth Tests | [integration-auth-tests.yml](integration-auth-tests.yml) | Run the integration test suite with Kubernetes authentication |
| SqlStore Integration Tests | [integration-sql-store-tests.yml](integration-sql-store-tests.yml) | Run the integration test suite with SqlStore | | SqlStore Integration Tests | [integration-sql-store-tests.yml](integration-sql-store-tests.yml) | Run the integration test suite with SqlStore |
| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suite from tests/integration in replay mode | | Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suites from tests/integration in replay mode |
| Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers | | Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers |
| Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks | | Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks |
| Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build | | Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build |

View file

@ -1,6 +1,6 @@
name: Integration Tests (Replay) name: Integration Tests (Replay)
run-name: Run the integration test suite from tests/integration in replay mode run-name: Run the integration test suites from tests/integration in replay mode
on: on:
push: push:
@ -32,14 +32,6 @@ on:
description: 'Test against a specific provider' description: 'Test against a specific provider'
type: string type: string
default: 'ollama' default: 'ollama'
test-subdirs:
description: 'Comma-separated list of test subdirectories to run'
type: string
default: ''
test-pattern:
description: 'Regex pattern to pass to pytest -k'
type: string
default: ''
concurrency: concurrency:
# Skip concurrency for pushes to main - each commit should be tested independently # Skip concurrency for pushes to main - each commit should be tested independently
@ -50,7 +42,7 @@ jobs:
run-replay-mode-tests: run-replay-mode-tests:
runs-on: ubuntu-latest runs-on: ubuntu-latest
name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }} name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, {4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.test-suite) }}
strategy: strategy:
fail-fast: false fail-fast: false
@ -61,7 +53,7 @@ jobs:
# Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12 # Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12
python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }} python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }}
client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }} client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }}
run-vision-tests: [true, false] test-suite: [base, vision]
steps: steps:
- name: Checkout repository - name: Checkout repository
@ -73,15 +65,13 @@ jobs:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
client-version: ${{ matrix.client-version }} client-version: ${{ matrix.client-version }}
provider: ${{ matrix.provider }} provider: ${{ matrix.provider }}
run-vision-tests: ${{ matrix.run-vision-tests }} test-suite: ${{ matrix.test-suite }}
inference-mode: 'replay' inference-mode: 'replay'
- name: Run tests - name: Run tests
uses: ./.github/actions/run-and-record-tests uses: ./.github/actions/run-and-record-tests
with: with:
test-subdirs: ${{ inputs.test-subdirs }}
test-pattern: ${{ inputs.test-pattern }}
stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }} stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }}
provider: ${{ matrix.provider }} provider: ${{ matrix.provider }}
inference-mode: 'replay' inference-mode: 'replay'
run-vision-tests: ${{ matrix.run-vision-tests }} test-suite: ${{ matrix.test-suite }}

View file

@ -10,18 +10,18 @@ run-name: Run the integration test suite from tests/integration
on: on:
workflow_dispatch: workflow_dispatch:
inputs: inputs:
test-subdirs:
description: 'Comma-separated list of test subdirectories to run'
type: string
default: ''
test-provider: test-provider:
description: 'Test against a specific provider' description: 'Test against a specific provider'
type: string type: string
default: 'ollama' default: 'ollama'
run-vision-tests: test-suite:
description: 'Whether to run vision tests' description: 'Test suite to use: base, responses, vision, etc.'
type: boolean type: string
default: false default: ''
test-subdirs:
description: 'Comma-separated list of test subdirectories to run; overrides test-suite'
type: string
default: ''
test-pattern: test-pattern:
description: 'Regex pattern to pass to pytest -k' description: 'Regex pattern to pass to pytest -k'
type: string type: string
@ -38,11 +38,11 @@ jobs:
- name: Echo workflow inputs - name: Echo workflow inputs
run: | run: |
echo "::group::Workflow Inputs" echo "::group::Workflow Inputs"
echo "test-subdirs: ${{ inputs.test-subdirs }}"
echo "test-provider: ${{ inputs.test-provider }}"
echo "run-vision-tests: ${{ inputs.run-vision-tests }}"
echo "test-pattern: ${{ inputs.test-pattern }}"
echo "branch: ${{ github.ref_name }}" echo "branch: ${{ github.ref_name }}"
echo "test-provider: ${{ inputs.test-provider }}"
echo "test-suite: ${{ inputs.test-suite }}"
echo "test-subdirs: ${{ inputs.test-subdirs }}"
echo "test-pattern: ${{ inputs.test-pattern }}"
echo "::endgroup::" echo "::endgroup::"
- name: Checkout repository - name: Checkout repository
@ -56,15 +56,15 @@ jobs:
python-version: "3.12" # Use single Python version for recording python-version: "3.12" # Use single Python version for recording
client-version: "latest" client-version: "latest"
provider: ${{ inputs.test-provider || 'ollama' }} provider: ${{ inputs.test-provider || 'ollama' }}
run-vision-tests: ${{ inputs.run-vision-tests }} test-suite: ${{ inputs.test-suite }}
inference-mode: 'record' inference-mode: 'record'
- name: Run and record tests - name: Run and record tests
uses: ./.github/actions/run-and-record-tests uses: ./.github/actions/run-and-record-tests
with: with:
test-pattern: ${{ inputs.test-pattern }}
test-subdirs: ${{ inputs.test-subdirs }}
stack-config: 'server:ci-tests' # recording must be done with server since more tests are run stack-config: 'server:ci-tests' # recording must be done with server since more tests are run
provider: ${{ inputs.test-provider || 'ollama' }} provider: ${{ inputs.test-provider || 'ollama' }}
inference-mode: 'record' inference-mode: 'record'
run-vision-tests: ${{ inputs.run-vision-tests }} test-suite: ${{ inputs.test-suite }}
test-subdirs: ${{ inputs.test-subdirs }}
test-pattern: ${{ inputs.test-pattern }}

View file

@ -86,7 +86,7 @@ repos:
language: python language: python
pass_filenames: false pass_filenames: false
require_serial: true require_serial: true
files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$ files: ^llama_stack/distributions/.*$|^llama_stack/providers/.*/inference/.*/models\.py$
- id: provider-codegen - id: provider-codegen
name: Provider Codegen name: Provider Codegen
additional_dependencies: additional_dependencies:

View file

@ -3,6 +3,7 @@ image_name: kubernetes-benchmark-demo
apis: apis:
- agents - agents
- inference - inference
- safety
- telemetry - telemetry
- tool_runtime - tool_runtime
- vector_io - vector_io
@ -30,6 +31,11 @@ providers:
db: ${env.POSTGRES_DB:=llamastack} db: ${env.POSTGRES_DB:=llamastack}
user: ${env.POSTGRES_USER:=llamastack} user: ${env.POSTGRES_USER:=llamastack}
password: ${env.POSTGRES_PASSWORD:=llamastack} password: ${env.POSTGRES_PASSWORD:=llamastack}
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents: agents:
- provider_id: meta-reference - provider_id: meta-reference
provider_type: inline::meta-reference provider_type: inline::meta-reference
@ -95,6 +101,8 @@ models:
- model_id: ${env.INFERENCE_MODEL} - model_id: ${env.INFERENCE_MODEL}
provider_id: vllm-inference provider_id: vllm-inference
model_type: llm model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B}
vector_dbs: [] vector_dbs: []
datasets: [] datasets: []
scoring_fns: [] scoring_fns: []

View file

@ -18,12 +18,13 @@ embedding_model_id = (
).identifier ).identifier
embedding_dimension = em.metadata["embedding_dimension"] embedding_dimension = em.metadata["embedding_dimension"]
_ = client.vector_dbs.register( vector_db = client.vector_dbs.register(
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
embedding_model=embedding_model_id, embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
provider_id="faiss", provider_id="faiss",
) )
vector_db_id = vector_db.identifier
source = "https://www.paulgraham.com/greatwork.html" source = "https://www.paulgraham.com/greatwork.html"
print("rag_tool> Ingesting document:", source) print("rag_tool> Ingesting document:", source)
document = RAGDocument( document = RAGDocument(
@ -35,7 +36,7 @@ document = RAGDocument(
client.tool_runtime.rag_tool.insert( client.tool_runtime.rag_tool.insert(
documents=[document], documents=[document],
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
chunk_size_in_tokens=50, chunk_size_in_tokens=100,
) )
agent = Agent( agent = Agent(
client, client,

View file

@ -15,8 +15,8 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE | | `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS | | `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE | | `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE |
| `connect_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. | | `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
| `read_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. | | `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). | | `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
## Sample Configuration ## Sample Configuration

View file

@ -15,8 +15,8 @@ AWS Bedrock safety provider for content moderation using AWS's safety services.
| `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE | | `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE |
| `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS | | `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS |
| `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE | | `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE |
| `connect_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. | | `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. |
| `read_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. | | `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. |
| `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). | | `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). |
## Sample Configuration ## Sample Configuration

View file

@ -527,7 +527,7 @@ class InferenceRouter(Inference):
# Store the response with the ID that will be returned to the client # Store the response with the ID that will be returned to the client
if self.store: if self.store:
await self.store.store_chat_completion(response, messages) asyncio.create_task(self.store.store_chat_completion(response, messages))
if self.telemetry: if self.telemetry:
metrics = self._construct_metrics( metrics = self._construct_metrics(
@ -855,4 +855,4 @@ class InferenceRouter(Inference):
object="chat.completion", object="chat.completion",
) )
logger.debug(f"InferenceRouter.completion_response: {final_response}") logger.debug(f"InferenceRouter.completion_response: {final_response}")
await self.store.store_chat_completion(final_response, messages) asyncio.create_task(self.store.store_chat_completion(final_response, messages))

View file

@ -52,7 +52,6 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
provider_vector_db_id: str | None = None, provider_vector_db_id: str | None = None,
vector_db_name: str | None = None, vector_db_name: str | None = None,
) -> VectorDB: ) -> VectorDB:
provider_vector_db_id = provider_vector_db_id or vector_db_id
if provider_id is None: if provider_id is None:
if len(self.impls_by_provider_id) > 0: if len(self.impls_by_provider_id) > 0:
provider_id = list(self.impls_by_provider_id.keys())[0] provider_id = list(self.impls_by_provider_id.keys())[0]
@ -69,14 +68,33 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding) raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding)
if "embedding_dimension" not in model.metadata: if "embedding_dimension" not in model.metadata:
raise ValueError(f"Model {embedding_model} does not have an embedding dimension") raise ValueError(f"Model {embedding_model} does not have an embedding dimension")
provider = self.impls_by_provider_id[provider_id]
logger.warning(
"VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly."
)
vector_store = await provider.openai_create_vector_store(
name=vector_db_name or vector_db_id,
embedding_model=embedding_model,
embedding_dimension=model.metadata["embedding_dimension"],
provider_id=provider_id,
provider_vector_db_id=provider_vector_db_id,
)
vector_store_id = vector_store.id
actual_provider_vector_db_id = provider_vector_db_id or vector_store_id
logger.warning(
f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name"
)
vector_db_data = { vector_db_data = {
"identifier": vector_db_id, "identifier": vector_store_id,
"type": ResourceType.vector_db.value, "type": ResourceType.vector_db.value,
"provider_id": provider_id, "provider_id": provider_id,
"provider_resource_id": provider_vector_db_id, "provider_resource_id": actual_provider_vector_db_id,
"embedding_model": embedding_model, "embedding_model": embedding_model,
"embedding_dimension": model.metadata["embedding_dimension"], "embedding_dimension": model.metadata["embedding_dimension"],
"vector_db_name": vector_db_name, "vector_db_name": vector_store.name,
} }
vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data) vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data)
await self.register_object(vector_db) await self.register_object(vector_db)

View file

@ -132,15 +132,17 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
}, },
) )
elif isinstance(exc, ConflictError): elif isinstance(exc, ConflictError):
return HTTPException(status_code=409, detail=str(exc)) return HTTPException(status_code=httpx.codes.CONFLICT, detail=str(exc))
elif isinstance(exc, ResourceNotFoundError): elif isinstance(exc, ResourceNotFoundError):
return HTTPException(status_code=404, detail=str(exc)) return HTTPException(status_code=httpx.codes.NOT_FOUND, detail=str(exc))
elif isinstance(exc, ValueError): elif isinstance(exc, ValueError):
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}") return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}")
elif isinstance(exc, BadRequestError): elif isinstance(exc, BadRequestError):
return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc)) return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc))
elif isinstance(exc, PermissionError | AccessDeniedError): elif isinstance(exc, PermissionError | AccessDeniedError):
return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}") return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}")
elif isinstance(exc, ConnectionError | httpx.ConnectError):
return HTTPException(status_code=httpx.codes.BAD_GATEWAY, detail=str(exc))
elif isinstance(exc, asyncio.TimeoutError | TimeoutError): elif isinstance(exc, asyncio.TimeoutError | TimeoutError):
return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}") return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}")
elif isinstance(exc, NotImplementedError): elif isinstance(exc, NotImplementedError):

View file

@ -11,9 +11,7 @@ from ..starter.starter import get_distribution_template as get_starter_distribut
def get_distribution_template() -> DistributionTemplate: def get_distribution_template() -> DistributionTemplate:
template = get_starter_distribution_template() template = get_starter_distribution_template(name="ci-tests")
name = "ci-tests"
template.name = name
template.description = "CI tests for Llama Stack" template.description = "CI tests for Llama Stack"
return template return template

View file

@ -89,28 +89,28 @@ providers:
config: config:
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/faiss_store.db
- provider_id: sqlite-vec - provider_id: sqlite-vec
provider_type: inline::sqlite-vec provider_type: inline::sqlite-vec
config: config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec.db
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec_registry.db
- provider_id: ${env.MILVUS_URL:+milvus} - provider_id: ${env.MILVUS_URL:+milvus}
provider_type: inline::milvus provider_type: inline::milvus
config: config:
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/ci-tests}/milvus.db
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/milvus_registry.db
- provider_id: ${env.CHROMADB_URL:+chromadb} - provider_id: ${env.CHROMADB_URL:+chromadb}
provider_type: remote::chromadb provider_type: remote::chromadb
config: config:
url: ${env.CHROMADB_URL:=} url: ${env.CHROMADB_URL:=}
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests/}/chroma_remote_registry.db
- provider_id: ${env.PGVECTOR_DB:+pgvector} - provider_id: ${env.PGVECTOR_DB:+pgvector}
provider_type: remote::pgvector provider_type: remote::pgvector
config: config:
@ -121,15 +121,15 @@ providers:
password: ${env.PGVECTOR_PASSWORD:=} password: ${env.PGVECTOR_PASSWORD:=}
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/pgvector_registry.db
files: files:
- provider_id: meta-reference-files - provider_id: meta-reference-files
provider_type: inline::localfs provider_type: inline::localfs
config: config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files} storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/ci-tests/files}
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/files_metadata.db
safety: safety:
- provider_id: llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard provider_type: inline::llama-guard

View file

@ -89,28 +89,28 @@ providers:
config: config:
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/faiss_store.db
- provider_id: sqlite-vec - provider_id: sqlite-vec
provider_type: inline::sqlite-vec provider_type: inline::sqlite-vec
config: config:
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/sqlite_vec.db
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/sqlite_vec_registry.db
- provider_id: ${env.MILVUS_URL:+milvus} - provider_id: ${env.MILVUS_URL:+milvus}
provider_type: inline::milvus provider_type: inline::milvus
config: config:
db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter-gpu}/milvus.db
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/milvus_registry.db
- provider_id: ${env.CHROMADB_URL:+chromadb} - provider_id: ${env.CHROMADB_URL:+chromadb}
provider_type: remote::chromadb provider_type: remote::chromadb
config: config:
url: ${env.CHROMADB_URL:=} url: ${env.CHROMADB_URL:=}
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu/}/chroma_remote_registry.db
- provider_id: ${env.PGVECTOR_DB:+pgvector} - provider_id: ${env.PGVECTOR_DB:+pgvector}
provider_type: remote::pgvector provider_type: remote::pgvector
config: config:
@ -121,15 +121,15 @@ providers:
password: ${env.PGVECTOR_PASSWORD:=} password: ${env.PGVECTOR_PASSWORD:=}
kvstore: kvstore:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/pgvector_registry.db
files: files:
- provider_id: meta-reference-files - provider_id: meta-reference-files
provider_type: inline::localfs provider_type: inline::localfs
config: config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files} storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter-gpu/files}
metadata_store: metadata_store:
type: sqlite type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/files_metadata.db
safety: safety:
- provider_id: llama-guard - provider_id: llama-guard
provider_type: inline::llama-guard provider_type: inline::llama-guard

View file

@ -11,9 +11,7 @@ from ..starter.starter import get_distribution_template as get_starter_distribut
def get_distribution_template() -> DistributionTemplate: def get_distribution_template() -> DistributionTemplate:
template = get_starter_distribution_template() template = get_starter_distribution_template(name="starter-gpu")
name = "starter-gpu"
template.name = name
template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments." template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments."
template.providers["post_training"] = [ template.providers["post_training"] = [

View file

@ -99,9 +99,8 @@ def get_remote_inference_providers() -> list[Provider]:
return inference_providers return inference_providers
def get_distribution_template() -> DistributionTemplate: def get_distribution_template(name: str = "starter") -> DistributionTemplate:
remote_inference_providers = get_remote_inference_providers() remote_inference_providers = get_remote_inference_providers()
name = "starter"
providers = { providers = {
"inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers] "inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers]

View file

@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches):
# TODO: set expiration time for garbage collection # TODO: set expiration time for garbage collection
if endpoint not in ["/v1/chat/completions"]: if endpoint not in ["/v1/chat/completions", "/v1/completions"]:
raise ValueError( raise ValueError(
f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint", f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint",
) )
if completion_window != "24h": if completion_window != "24h":
@ -424,13 +424,21 @@ class ReferenceBatchesImpl(Batches):
) )
valid = False valid = False
for param, expected_type, type_string in [ if batch.endpoint == "/v1/chat/completions":
required_params = [
("model", str, "a string"), ("model", str, "a string"),
# messages is specific to /v1/chat/completions # messages is specific to /v1/chat/completions
# we could skip validating messages here and let inference fail. however, # we could skip validating messages here and let inference fail. however,
# that would be a very expensive way to find out messages is wrong. # that would be a very expensive way to find out messages is wrong.
("messages", list, "an array"), # TODO: allow messages to be a string? ("messages", list, "an array"), # TODO: allow messages to be a string?
]: ]
else: # /v1/completions
required_params = [
("model", str, "a string"),
("prompt", str, "a string"), # TODO: allow prompt to be a list of strings??
]
for param, expected_type, type_string in required_params:
if param not in body: if param not in body:
errors.append( errors.append(
BatchError( BatchError(
@ -591,6 +599,7 @@ class ReferenceBatchesImpl(Batches):
try: try:
# TODO(SECURITY): review body for security issues # TODO(SECURITY): review body for security issues
if request.url == "/v1/chat/completions":
request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]] request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]]
chat_response = await self.inference_api.openai_chat_completion(**request.body) chat_response = await self.inference_api.openai_chat_completion(**request.body)
@ -605,6 +614,22 @@ class ReferenceBatchesImpl(Batches):
"body": chat_response.model_dump_json(), "body": chat_response.model_dump_json(),
}, },
} }
else: # /v1/completions
completion_response = await self.inference_api.openai_completion(**request.body)
# this is for mypy, we don't allow streaming so we'll get the right type
assert hasattr(completion_response, "model_dump_json"), (
"Completion response must have model_dump_json method"
)
return {
"id": request_id,
"custom_id": request.custom_id,
"response": {
"status_code": 200,
"request_id": request_id,
"body": completion_response.model_dump_json(),
},
}
except Exception as e: except Exception as e:
logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}") logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}")
return { return {

View file

@ -14,6 +14,6 @@ from .config import RagToolRuntimeConfig
async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]): async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
from .memory import MemoryToolRuntimeImpl from .memory import MemoryToolRuntimeImpl
impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference]) impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
await impl.initialize() await impl.initialize()
return impl return impl

View file

@ -5,10 +5,15 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import base64
import io
import mimetypes
import secrets import secrets
import string import string
from typing import Any from typing import Any
import httpx
from fastapi import UploadFile
from pydantic import TypeAdapter from pydantic import TypeAdapter
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
@ -17,6 +22,7 @@ from llama_stack.apis.common.content_types import (
InterleavedContentItem, InterleavedContentItem,
TextContentItem, TextContentItem,
) )
from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import Inference from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import ( from llama_stack.apis.tools import (
ListToolDefsResponse, ListToolDefsResponse,
@ -30,13 +36,18 @@ from llama_stack.apis.tools import (
ToolParameter, ToolParameter,
ToolRuntime, ToolRuntime,
) )
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO from llama_stack.apis.vector_io import (
QueryChunksResponse,
VectorIO,
VectorStoreChunkingStrategyStatic,
VectorStoreChunkingStrategyStaticConfig,
)
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
content_from_doc, content_from_doc,
make_overlapped_chunks, parse_data_url,
) )
from .config import RagToolRuntimeConfig from .config import RagToolRuntimeConfig
@ -55,10 +66,12 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
config: RagToolRuntimeConfig, config: RagToolRuntimeConfig,
vector_io_api: VectorIO, vector_io_api: VectorIO,
inference_api: Inference, inference_api: Inference,
files_api: Files,
): ):
self.config = config self.config = config
self.vector_io_api = vector_io_api self.vector_io_api = vector_io_api
self.inference_api = inference_api self.inference_api = inference_api
self.files_api = files_api
async def initialize(self): async def initialize(self):
pass pass
@ -78,26 +91,49 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
vector_db_id: str, vector_db_id: str,
chunk_size_in_tokens: int = 512, chunk_size_in_tokens: int = 512,
) -> None: ) -> None:
chunks = [] if not documents:
for doc in documents:
content = await content_from_doc(doc)
# TODO: we should add enrichment here as URLs won't be added to the metadata by default
chunks.extend(
make_overlapped_chunks(
doc.document_id,
content,
chunk_size_in_tokens,
chunk_size_in_tokens // 4,
doc.metadata,
)
)
if not chunks:
return return
await self.vector_io_api.insert_chunks( for doc in documents:
chunks=chunks, if isinstance(doc.content, URL):
vector_db_id=vector_db_id, if doc.content.uri.startswith("data:"):
parts = parse_data_url(doc.content.uri)
file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode()
mime_type = parts["mimetype"]
else:
async with httpx.AsyncClient() as client:
response = await client.get(doc.content.uri)
file_data = response.content
mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream")
else:
content_str = await content_from_doc(doc)
file_data = content_str.encode("utf-8")
mime_type = doc.mime_type or "text/plain"
file_extension = mimetypes.guess_extension(mime_type) or ".txt"
filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}")
file_obj = io.BytesIO(file_data)
file_obj.name = filename
upload_file = UploadFile(file=file_obj, filename=filename)
created_file = await self.files_api.openai_upload_file(
file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS
)
chunking_strategy = VectorStoreChunkingStrategyStatic(
static=VectorStoreChunkingStrategyStaticConfig(
max_chunk_size_tokens=chunk_size_in_tokens,
chunk_overlap_tokens=chunk_size_in_tokens // 4,
)
)
await self.vector_io_api.openai_attach_file_to_vector_store(
vector_store_id=vector_db_id,
file_id=created_file.id,
attributes=doc.metadata,
chunking_strategy=chunking_strategy,
) )
async def query( async def query(

View file

@ -116,7 +116,7 @@ def available_providers() -> list[ProviderSpec]:
adapter=AdapterSpec( adapter=AdapterSpec(
adapter_type="fireworks", adapter_type="fireworks",
pip_packages=[ pip_packages=[
"fireworks-ai<=0.18.0", "fireworks-ai<=0.17.16",
], ],
module="llama_stack.providers.remote.inference.fireworks", module="llama_stack.providers.remote.inference.fireworks",
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig", config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
@ -207,7 +207,7 @@ def available_providers() -> list[ProviderSpec]:
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter=AdapterSpec(
adapter_type="gemini", adapter_type="gemini",
pip_packages=["litellm"], pip_packages=["litellm", "openai"],
module="llama_stack.providers.remote.inference.gemini", module="llama_stack.providers.remote.inference.gemini",
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig", config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator", provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
@ -270,7 +270,7 @@ Available Models:
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter=AdapterSpec(
adapter_type="sambanova", adapter_type="sambanova",
pip_packages=["litellm"], pip_packages=["litellm", "openai"],
module="llama_stack.providers.remote.inference.sambanova", module="llama_stack.providers.remote.inference.sambanova",
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig", config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator", provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",

View file

@ -32,7 +32,7 @@ def available_providers() -> list[ProviderSpec]:
], ],
module="llama_stack.providers.inline.tool_runtime.rag", module="llama_stack.providers.inline.tool_runtime.rag",
config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig", config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
api_dependencies=[Api.vector_io, Api.inference], api_dependencies=[Api.vector_io, Api.inference, Api.files],
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.", description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
), ),
remote_provider_spec( remote_provider_spec(

View file

@ -5,12 +5,13 @@
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import GeminiConfig from .config import GeminiConfig
from .models import MODEL_ENTRIES from .models import MODEL_ENTRIES
class GeminiInferenceAdapter(LiteLLMOpenAIMixin): class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
def __init__(self, config: GeminiConfig) -> None: def __init__(self, config: GeminiConfig) -> None:
LiteLLMOpenAIMixin.__init__( LiteLLMOpenAIMixin.__init__(
self, self,
@ -21,6 +22,11 @@ class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
) )
self.config = config self.config = config
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self):
return "https://generativelanguage.googleapis.com/v1beta/openai/"
async def initialize(self) -> None: async def initialize(self) -> None:
await super().initialize() await super().initialize()

View file

@ -4,13 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import SambaNovaImplConfig from .config import SambaNovaImplConfig
from .models import MODEL_ENTRIES from .models import MODEL_ENTRIES
class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin): class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):
"""
SambaNova Inference Adapter for Llama Stack.
Note: The inheritance order is important here. OpenAIMixin must come before
LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability()
is used instead of LiteLLMOpenAIMixin.check_model_availability().
- OpenAIMixin.check_model_availability() queries the /v1/models to check if a model exists
- LiteLLMOpenAIMixin.check_model_availability() checks the static registry within LiteLLM
"""
def __init__(self, config: SambaNovaImplConfig): def __init__(self, config: SambaNovaImplConfig):
self.config = config self.config = config
self.environment_available_models = [] self.environment_available_models = []
@ -24,3 +37,14 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin):
download_images=True, # SambaNova requires base64 image encoding download_images=True, # SambaNova requires base64 image encoding
json_schema_strict=False, # SambaNova doesn't support strict=True yet json_schema_strict=False, # SambaNova doesn't support strict=True yet
) )
# Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin
get_api_key = LiteLLMOpenAIMixin.get_api_key
def get_base_url(self) -> str:
"""
Get the base URL for OpenAI mixin.
:return: The SambaNova base URL
"""
return self.config.url

View file

@ -4,53 +4,55 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import os
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
class BedrockBaseConfig(BaseModel): class BedrockBaseConfig(BaseModel):
aws_access_key_id: str | None = Field( aws_access_key_id: str | None = Field(
default=None, default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"),
description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID", description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID",
) )
aws_secret_access_key: str | None = Field( aws_secret_access_key: str | None = Field(
default=None, default_factory=lambda: os.getenv("AWS_SECRET_ACCESS_KEY"),
description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY", description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY",
) )
aws_session_token: str | None = Field( aws_session_token: str | None = Field(
default=None, default_factory=lambda: os.getenv("AWS_SESSION_TOKEN"),
description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN", description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN",
) )
region_name: str | None = Field( region_name: str | None = Field(
default=None, default_factory=lambda: os.getenv("AWS_DEFAULT_REGION"),
description="The default AWS Region to use, for example, us-west-1 or us-west-2." description="The default AWS Region to use, for example, us-west-1 or us-west-2."
"Default use environment variable: AWS_DEFAULT_REGION", "Default use environment variable: AWS_DEFAULT_REGION",
) )
profile_name: str | None = Field( profile_name: str | None = Field(
default=None, default_factory=lambda: os.getenv("AWS_PROFILE"),
description="The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE", description="The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE",
) )
total_max_attempts: int | None = Field( total_max_attempts: int | None = Field(
default=None, default_factory=lambda: int(val) if (val := os.getenv("AWS_MAX_ATTEMPTS")) else None,
description="An integer representing the maximum number of attempts that will be made for a single request, " description="An integer representing the maximum number of attempts that will be made for a single request, "
"including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS", "including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS",
) )
retry_mode: str | None = Field( retry_mode: str | None = Field(
default=None, default_factory=lambda: os.getenv("AWS_RETRY_MODE"),
description="A string representing the type of retries Boto3 will perform." description="A string representing the type of retries Boto3 will perform."
"Default use environment variable: AWS_RETRY_MODE", "Default use environment variable: AWS_RETRY_MODE",
) )
connect_timeout: float | None = Field( connect_timeout: float | None = Field(
default=60, default_factory=lambda: float(os.getenv("AWS_CONNECT_TIMEOUT", "60")),
description="The time in seconds till a timeout exception is thrown when attempting to make a connection. " description="The time in seconds till a timeout exception is thrown when attempting to make a connection. "
"The default is 60 seconds.", "The default is 60 seconds.",
) )
read_timeout: float | None = Field( read_timeout: float | None = Field(
default=60, default_factory=lambda: float(os.getenv("AWS_READ_TIMEOUT", "60")),
description="The time in seconds till a timeout exception is thrown when attempting to read from a connection." description="The time in seconds till a timeout exception is thrown when attempting to read from a connection."
"The default is 60 seconds.", "The default is 60 seconds.",
) )
session_ttl: int | None = Field( session_ttl: int | None = Field(
default=3600, default_factory=lambda: int(os.getenv("AWS_SESSION_TTL", "3600")),
description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).", description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).",
) )

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import asyncio
import base64 import base64
import struct import struct
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
@ -43,9 +44,11 @@ class SentenceTransformerEmbeddingMixin:
task_type: EmbeddingTaskType | None = None, task_type: EmbeddingTaskType | None = None,
) -> EmbeddingsResponse: ) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id) model = await self.model_store.get_model(model_id)
embedding_model = self._load_sentence_transformer_model(model.provider_resource_id) embedding_model = await self._load_sentence_transformer_model(model.provider_resource_id)
embeddings = embedding_model.encode( embeddings = await asyncio.to_thread(
[interleaved_content_as_str(content) for content in contents], show_progress_bar=False embedding_model.encode,
[interleaved_content_as_str(content) for content in contents],
show_progress_bar=False,
) )
return EmbeddingsResponse(embeddings=embeddings) return EmbeddingsResponse(embeddings=embeddings)
@ -64,8 +67,8 @@ class SentenceTransformerEmbeddingMixin:
# Get the model and generate embeddings # Get the model and generate embeddings
model_obj = await self.model_store.get_model(model) model_obj = await self.model_store.get_model(model)
embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id) embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id)
embeddings = embedding_model.encode(input_list, show_progress_bar=False) embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False)
# Convert embeddings to the requested format # Convert embeddings to the requested format
data = [] data = []
@ -93,7 +96,7 @@ class SentenceTransformerEmbeddingMixin:
usage=usage, usage=usage,
) )
def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer": async def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
global EMBEDDING_MODELS global EMBEDDING_MODELS
loaded_model = EMBEDDING_MODELS.get(model) loaded_model = EMBEDDING_MODELS.get(model)
@ -101,8 +104,12 @@ class SentenceTransformerEmbeddingMixin:
return loaded_model return loaded_model
log.info(f"Loading sentence transformer for {model}...") log.info(f"Loading sentence transformer for {model}...")
def _load_model():
from sentence_transformers import SentenceTransformer from sentence_transformers import SentenceTransformer
loaded_model = SentenceTransformer(model) return SentenceTransformer(model)
loaded_model = await asyncio.to_thread(_load_model)
EMBEDDING_MODELS[model] = loaded_model EMBEDDING_MODELS[model] = loaded_model
return loaded_model return loaded_model

View file

@ -67,6 +67,38 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat
raise AuthenticationRequiredError(exc) from exc raise AuthenticationRequiredError(exc) from exc
if i == len(connection_strategies) - 1: if i == len(connection_strategies) - 1:
raise raise
except* httpx.ConnectError as eg:
# Connection refused, server down, network unreachable
if i == len(connection_strategies) - 1:
error_msg = f"Failed to connect to MCP server at {endpoint}: Connection refused"
logger.error(f"MCP connection error: {error_msg}")
raise ConnectionError(error_msg) from eg
else:
logger.warning(
f"failed to connect to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
)
except* httpx.TimeoutException as eg:
# Request timeout, server too slow
if i == len(connection_strategies) - 1:
error_msg = f"MCP server at {endpoint} timed out"
logger.error(f"MCP timeout error: {error_msg}")
raise TimeoutError(error_msg) from eg
else:
logger.warning(
f"MCP server at {endpoint} timed out via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
)
except* httpx.RequestError as eg:
# DNS resolution failures, network errors, invalid URLs
if i == len(connection_strategies) - 1:
# Get the first exception's message for the error string
exc_msg = str(eg.exceptions[0]) if eg.exceptions else "Unknown error"
error_msg = f"Network error connecting to MCP server at {endpoint}: {exc_msg}"
logger.error(f"MCP network error: {error_msg}")
raise ConnectionError(error_msg) from eg
else:
logger.warning(
f"network error connecting to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}"
)
except* McpError: except* McpError:
if i < len(connection_strategies) - 1: if i < len(connection_strategies) - 1:
logger.warning( logger.warning(

View file

@ -15,7 +15,7 @@ set -euo pipefail
BRANCH="" BRANCH=""
TEST_SUBDIRS="" TEST_SUBDIRS=""
TEST_PROVIDER="ollama" TEST_PROVIDER="ollama"
RUN_VISION_TESTS=false TEST_SUITE="base"
TEST_PATTERN="" TEST_PATTERN=""
# Help function # Help function
@ -27,9 +27,9 @@ Trigger the integration test recording workflow remotely. This way you do not ne
OPTIONS: OPTIONS:
-b, --branch BRANCH Branch to run the workflow on (defaults to current branch) -b, --branch BRANCH Branch to run the workflow on (defaults to current branch)
-s, --test-subdirs DIRS Comma-separated list of test subdirectories to run (REQUIRED)
-p, --test-provider PROVIDER Test provider to use: vllm or ollama (default: ollama) -p, --test-provider PROVIDER Test provider to use: vllm or ollama (default: ollama)
-v, --run-vision-tests Include vision tests in the recording -t, --test-suite SUITE Test suite to use: base, responses, vision, etc. (default: base)
-s, --test-subdirs DIRS Comma-separated list of test subdirectories to run (overrides suite)
-k, --test-pattern PATTERN Regex pattern to pass to pytest -k -k, --test-pattern PATTERN Regex pattern to pass to pytest -k
-h, --help Show this help message -h, --help Show this help message
@ -38,7 +38,7 @@ EXAMPLES:
$0 --test-subdirs "agents" $0 --test-subdirs "agents"
# Record tests for specific branch with vision tests # Record tests for specific branch with vision tests
$0 -b my-feature-branch --test-subdirs "inference" --run-vision-tests $0 -b my-feature-branch --test-suite vision
# Record multiple test subdirectories with specific provider # Record multiple test subdirectories with specific provider
$0 --test-subdirs "agents,inference" --test-provider vllm $0 --test-subdirs "agents,inference" --test-provider vllm
@ -71,9 +71,9 @@ while [[ $# -gt 0 ]]; do
TEST_PROVIDER="$2" TEST_PROVIDER="$2"
shift 2 shift 2
;; ;;
-v|--run-vision-tests) -t|--test-suite)
RUN_VISION_TESTS=true TEST_SUITE="$2"
shift shift 2
;; ;;
-k|--test-pattern) -k|--test-pattern)
TEST_PATTERN="$2" TEST_PATTERN="$2"
@ -92,11 +92,11 @@ while [[ $# -gt 0 ]]; do
done done
# Validate required parameters # Validate required parameters
if [[ -z "$TEST_SUBDIRS" ]]; then if [[ -z "$TEST_SUBDIRS" && -z "$TEST_SUITE" ]]; then
echo "Error: --test-subdirs is required" echo "Error: --test-subdirs or --test-suite is required"
echo "Please specify which test subdirectories to run, e.g.:" echo "Please specify which test subdirectories to run or test suite to use, e.g.:"
echo " $0 --test-subdirs \"agents,inference\"" echo " $0 --test-subdirs \"agents,inference\""
echo " $0 --test-subdirs \"inference\" --run-vision-tests" echo " $0 --test-suite vision"
echo "" echo ""
exit 1 exit 1
fi fi
@ -239,17 +239,19 @@ echo "Triggering integration test recording workflow..."
echo "Branch: $BRANCH" echo "Branch: $BRANCH"
echo "Test provider: $TEST_PROVIDER" echo "Test provider: $TEST_PROVIDER"
echo "Test subdirs: $TEST_SUBDIRS" echo "Test subdirs: $TEST_SUBDIRS"
echo "Run vision tests: $RUN_VISION_TESTS" echo "Test suite: $TEST_SUITE"
echo "Test pattern: ${TEST_PATTERN:-"(none)"}" echo "Test pattern: ${TEST_PATTERN:-"(none)"}"
echo "" echo ""
# Prepare inputs for gh workflow run # Prepare inputs for gh workflow run
INPUTS="-f test-subdirs='$TEST_SUBDIRS'" if [[ -n "$TEST_SUBDIRS" ]]; then
INPUTS="-f test-subdirs='$TEST_SUBDIRS'"
fi
if [[ -n "$TEST_PROVIDER" ]]; then if [[ -n "$TEST_PROVIDER" ]]; then
INPUTS="$INPUTS -f test-provider='$TEST_PROVIDER'" INPUTS="$INPUTS -f test-provider='$TEST_PROVIDER'"
fi fi
if [[ "$RUN_VISION_TESTS" == "true" ]]; then if [[ -n "$TEST_SUITE" ]]; then
INPUTS="$INPUTS -f run-vision-tests=true" INPUTS="$INPUTS -f test-suite='$TEST_SUITE'"
fi fi
if [[ -n "$TEST_PATTERN" ]]; then if [[ -n "$TEST_PATTERN" ]]; then
INPUTS="$INPUTS -f test-pattern='$TEST_PATTERN'" INPUTS="$INPUTS -f test-pattern='$TEST_PATTERN'"

View file

@ -16,7 +16,7 @@ STACK_CONFIG=""
PROVIDER="" PROVIDER=""
TEST_SUBDIRS="" TEST_SUBDIRS=""
TEST_PATTERN="" TEST_PATTERN=""
RUN_VISION_TESTS="false" TEST_SUITE="base"
INFERENCE_MODE="replay" INFERENCE_MODE="replay"
EXTRA_PARAMS="" EXTRA_PARAMS=""
@ -28,12 +28,16 @@ Usage: $0 [OPTIONS]
Options: Options:
--stack-config STRING Stack configuration to use (required) --stack-config STRING Stack configuration to use (required)
--provider STRING Provider to use (ollama, vllm, etc.) (required) --provider STRING Provider to use (ollama, vllm, etc.) (required)
--test-subdirs STRING Comma-separated list of test subdirectories to run (default: 'inference') --test-suite STRING Comma-separated list of test suites to run (default: 'base')
--run-vision-tests Run vision tests instead of regular tests
--inference-mode STRING Inference mode: record or replay (default: replay) --inference-mode STRING Inference mode: record or replay (default: replay)
--test-subdirs STRING Comma-separated list of test subdirectories to run (overrides suite)
--test-pattern STRING Regex pattern to pass to pytest -k --test-pattern STRING Regex pattern to pass to pytest -k
--help Show this help message --help Show this help message
Suites are defined in tests/integration/suites.py. They are used to narrow the collection of tests and provide default model options.
You can also specify subdirectories (of tests/integration) to select tests from, which will override the suite.
Examples: Examples:
# Basic inference tests with ollama # Basic inference tests with ollama
$0 --stack-config server:ci-tests --provider ollama $0 --stack-config server:ci-tests --provider ollama
@ -42,7 +46,7 @@ Examples:
$0 --stack-config server:ci-tests --provider vllm --test-subdirs 'inference,agents' $0 --stack-config server:ci-tests --provider vllm --test-subdirs 'inference,agents'
# Vision tests with ollama # Vision tests with ollama
$0 --stack-config server:ci-tests --provider ollama --run-vision-tests $0 --stack-config server:ci-tests --provider ollama --test-suite vision
# Record mode for updating test recordings # Record mode for updating test recordings
$0 --stack-config server:ci-tests --provider ollama --inference-mode record $0 --stack-config server:ci-tests --provider ollama --inference-mode record
@ -64,9 +68,9 @@ while [[ $# -gt 0 ]]; do
TEST_SUBDIRS="$2" TEST_SUBDIRS="$2"
shift 2 shift 2
;; ;;
--run-vision-tests) --test-suite)
RUN_VISION_TESTS="true" TEST_SUITE="$2"
shift shift 2
;; ;;
--inference-mode) --inference-mode)
INFERENCE_MODE="$2" INFERENCE_MODE="$2"
@ -92,22 +96,25 @@ done
# Validate required parameters # Validate required parameters
if [[ -z "$STACK_CONFIG" ]]; then if [[ -z "$STACK_CONFIG" ]]; then
echo "Error: --stack-config is required" echo "Error: --stack-config is required"
usage
exit 1 exit 1
fi fi
if [[ -z "$PROVIDER" ]]; then if [[ -z "$PROVIDER" ]]; then
echo "Error: --provider is required" echo "Error: --provider is required"
usage exit 1
fi
if [[ -z "$TEST_SUITE" && -z "$TEST_SUBDIRS" ]]; then
echo "Error: --test-suite or --test-subdirs is required"
exit 1 exit 1
fi fi
echo "=== Llama Stack Integration Test Runner ===" echo "=== Llama Stack Integration Test Runner ==="
echo "Stack Config: $STACK_CONFIG" echo "Stack Config: $STACK_CONFIG"
echo "Provider: $PROVIDER" echo "Provider: $PROVIDER"
echo "Test Subdirs: $TEST_SUBDIRS"
echo "Vision Tests: $RUN_VISION_TESTS"
echo "Inference Mode: $INFERENCE_MODE" echo "Inference Mode: $INFERENCE_MODE"
echo "Test Suite: $TEST_SUITE"
echo "Test Subdirs: $TEST_SUBDIRS"
echo "Test Pattern: $TEST_PATTERN" echo "Test Pattern: $TEST_PATTERN"
echo "" echo ""
@ -194,56 +201,12 @@ if [[ -n "$TEST_PATTERN" ]]; then
PYTEST_PATTERN="${PYTEST_PATTERN} and $TEST_PATTERN" PYTEST_PATTERN="${PYTEST_PATTERN} and $TEST_PATTERN"
fi fi
# Run vision tests if specified
if [[ "$RUN_VISION_TESTS" == "true" ]]; then
echo "Running vision tests..."
set +e
pytest -s -v tests/integration/inference/test_vision_inference.py \
--stack-config="$STACK_CONFIG" \
-k "$PYTEST_PATTERN" \
--vision-model=ollama/llama3.2-vision:11b \
--embedding-model=sentence-transformers/all-MiniLM-L6-v2 \
--color=yes $EXTRA_PARAMS \
--capture=tee-sys
exit_code=$?
set -e
if [ $exit_code -eq 0 ]; then
echo "✅ Vision tests completed successfully"
elif [ $exit_code -eq 5 ]; then
echo "⚠️ No vision tests collected (pattern matched no tests)"
else
echo "❌ Vision tests failed"
exit 1
fi
exit 0
fi
# Run regular tests
if [[ -z "$TEST_SUBDIRS" ]]; then
TEST_SUBDIRS=$(find tests/integration -maxdepth 1 -mindepth 1 -type d |
sed 's|tests/integration/||' |
grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" |
sort)
fi
echo "Test subdirs to run: $TEST_SUBDIRS" echo "Test subdirs to run: $TEST_SUBDIRS"
# Collect all test files for the specified test types if [[ -n "$TEST_SUBDIRS" ]]; then
TEST_FILES="" # Collect all test files for the specified test types
for test_subdir in $(echo "$TEST_SUBDIRS" | tr ',' '\n'); do TEST_FILES=""
# Skip certain test types for vllm provider for test_subdir in $(echo "$TEST_SUBDIRS" | tr ',' '\n'); do
if [[ "$PROVIDER" == "vllm" ]]; then
if [[ "$test_subdir" == "safety" ]] || [[ "$test_subdir" == "post_training" ]] || [[ "$test_subdir" == "tool_runtime" ]]; then
echo "Skipping $test_subdir for vllm provider"
continue
fi
fi
if [[ "$STACK_CONFIG" != *"server:"* ]] && [[ "$test_subdir" == "batches" ]]; then
echo "Skipping $test_subdir for library client until types are supported"
continue
fi
if [[ -d "tests/integration/$test_subdir" ]]; then if [[ -d "tests/integration/$test_subdir" ]]; then
# Find all Python test files in this directory # Find all Python test files in this directory
test_files=$(find tests/integration/$test_subdir -name "test_*.py" -o -name "*_test.py") test_files=$(find tests/integration/$test_subdir -name "test_*.py" -o -name "*_test.py")
@ -254,24 +217,30 @@ for test_subdir in $(echo "$TEST_SUBDIRS" | tr ',' '\n'); do
else else
echo "Warning: Directory tests/integration/$test_subdir does not exist" echo "Warning: Directory tests/integration/$test_subdir does not exist"
fi fi
done done
if [[ -z "$TEST_FILES" ]]; then if [[ -z "$TEST_FILES" ]]; then
echo "No test files found for the specified test types" echo "No test files found for the specified test types"
exit 1 exit 1
fi
echo ""
echo "=== Running all collected tests in a single pytest command ==="
echo "Total test files: $(echo $TEST_FILES | wc -w)"
PYTEST_TARGET="$TEST_FILES"
EXTRA_PARAMS="$EXTRA_PARAMS --text-model=$TEXT_MODEL --embedding-model=sentence-transformers/all-MiniLM-L6-v2"
else
PYTEST_TARGET="tests/integration/"
EXTRA_PARAMS="$EXTRA_PARAMS --suite=$TEST_SUITE"
fi fi
echo ""
echo "=== Running all collected tests in a single pytest command ==="
echo "Total test files: $(echo $TEST_FILES | wc -w)"
set +e set +e
pytest -s -v $TEST_FILES \ pytest -s -v $PYTEST_TARGET \
--stack-config="$STACK_CONFIG" \ --stack-config="$STACK_CONFIG" \
-k "$PYTEST_PATTERN" \ -k "$PYTEST_PATTERN" \
--text-model="$TEXT_MODEL" \ $EXTRA_PARAMS \
--embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ --color=yes \
--color=yes $EXTRA_PARAMS \
--capture=tee-sys --capture=tee-sys
exit_code=$? exit_code=$?
set -e set -e
@ -294,7 +263,13 @@ df -h
# stop server # stop server
if [[ "$STACK_CONFIG" == *"server:"* ]]; then if [[ "$STACK_CONFIG" == *"server:"* ]]; then
echo "Stopping Llama Stack Server..." echo "Stopping Llama Stack Server..."
kill $(lsof -i :8321 | awk 'NR>1 {print $2}') pids=$(lsof -i :8321 | awk 'NR>1 {print $2}')
if [[ -n "$pids" ]]; then
echo "Killing Llama Stack Server processes: $pids"
kill -9 $pids
else
echo "No Llama Stack Server processes found ?!"
fi
echo "Llama Stack Server stopped" echo "Llama Stack Server stopped"
fi fi

View file

@ -77,7 +77,7 @@ You must be careful when re-recording. CI workflows assume a specific setup for
./scripts/github/schedule-record-workflow.sh --test-subdirs "agents,inference" ./scripts/github/schedule-record-workflow.sh --test-subdirs "agents,inference"
# Record with vision tests enabled # Record with vision tests enabled
./scripts/github/schedule-record-workflow.sh --test-subdirs "inference" --run-vision-tests ./scripts/github/schedule-record-workflow.sh --test-suite vision
# Record with specific provider # Record with specific provider
./scripts/github/schedule-record-workflow.sh --test-subdirs "agents" --test-provider vllm ./scripts/github/schedule-record-workflow.sh --test-subdirs "agents" --test-provider vllm

View file

@ -42,6 +42,27 @@ Model parameters can be influenced by the following options:
Each of these are comma-separated lists and can be used to generate multiple parameter combinations. Note that tests will be skipped Each of these are comma-separated lists and can be used to generate multiple parameter combinations. Note that tests will be skipped
if no model is specified. if no model is specified.
### Suites (fast selection + sane defaults)
- `--suite`: comma-separated list of named suites that both narrow which tests are collected and prefill common model options (unless you pass them explicitly).
- Available suites:
- `responses`: collects tests under `tests/integration/responses`; this is a separate suite because it needs a strong tool-calling model.
- `vision`: collects only `tests/integration/inference/test_vision_inference.py`; defaults `--vision-model=ollama/llama3.2-vision:11b`, `--embedding-model=sentence-transformers/all-MiniLM-L6-v2`.
- Explicit flags always win. For example, `--suite=responses --text-model=<X>` overrides the suites text model.
Examples:
```bash
# Fast responses run with defaults
pytest -s -v tests/integration --stack-config=server:starter --suite=responses
# Fast single-file vision run with defaults
pytest -s -v tests/integration --stack-config=server:starter --suite=vision
# Combine suites and override a default
pytest -s -v tests/integration --stack-config=server:starter --suite=responses,vision --embedding-model=text-embedding-3-small
```
## Examples ## Examples
### Testing against a Server ### Testing against a Server

View file

@ -268,3 +268,58 @@ class TestBatchesIntegration:
deleted_error_file = openai_client.files.delete(final_batch.error_file_id) deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully" assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully"
def test_batch_e2e_completions(self, openai_client, batch_helper, text_model_id):
"""Run an end-to-end batch with a single successful text completion request."""
request_body = {"model": text_model_id, "prompt": "Say completions", "max_tokens": 20}
batch_requests = [
{
"custom_id": "success-1",
"method": "POST",
"url": "/v1/completions",
"body": request_body,
}
]
with batch_helper.create_file(batch_requests) as uploaded_file:
batch = openai_client.batches.create(
input_file_id=uploaded_file.id,
endpoint="/v1/completions",
completion_window="24h",
metadata={"test": "e2e_completions_success"},
)
final_batch = batch_helper.wait_for(
batch.id,
max_wait_time=3 * 60,
expected_statuses={"completed"},
timeout_action="skip",
)
assert final_batch.status == "completed"
assert final_batch.request_counts is not None
assert final_batch.request_counts.total == 1
assert final_batch.request_counts.completed == 1
assert final_batch.output_file_id is not None
output_content = openai_client.files.content(final_batch.output_file_id)
if isinstance(output_content, str):
output_text = output_content
else:
output_text = output_content.content.decode("utf-8")
output_lines = output_text.strip().split("\n")
assert len(output_lines) == 1
result = json.loads(output_lines[0])
assert result["custom_id"] == "success-1"
assert "response" in result
assert result["response"]["status_code"] == 200
deleted_output_file = openai_client.files.delete(final_batch.output_file_id)
assert deleted_output_file.deleted
if final_batch.error_file_id is not None:
deleted_error_file = openai_client.files.delete(final_batch.error_file_id)
assert deleted_error_file.deleted

View file

@ -6,15 +6,17 @@
import inspect import inspect
import itertools import itertools
import os import os
import platform
import textwrap import textwrap
import time import time
from pathlib import Path
import pytest import pytest
from dotenv import load_dotenv from dotenv import load_dotenv
from llama_stack.log import get_logger from llama_stack.log import get_logger
from .suites import SUITE_DEFINITIONS
logger = get_logger(__name__, category="tests") logger = get_logger(__name__, category="tests")
@ -61,9 +63,22 @@ def pytest_configure(config):
key, value = env_var.split("=", 1) key, value = env_var.split("=", 1)
os.environ[key] = value os.environ[key] = value
if platform.system() == "Darwin": # Darwin is the system name for macOS suites_raw = config.getoption("--suite")
os.environ["DISABLE_CODE_SANDBOX"] = "1" suites: list[str] = []
logger.info("Setting DISABLE_CODE_SANDBOX=1 for macOS") if suites_raw:
suites = [p.strip() for p in str(suites_raw).split(",") if p.strip()]
unknown = [p for p in suites if p not in SUITE_DEFINITIONS]
if unknown:
raise pytest.UsageError(
f"Unknown suite(s): {', '.join(unknown)}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}"
)
for suite in suites:
suite_def = SUITE_DEFINITIONS.get(suite, {})
defaults: dict = suite_def.get("defaults", {})
for dest, value in defaults.items():
current = getattr(config.option, dest, None)
if not current:
setattr(config.option, dest, value)
def pytest_addoption(parser): def pytest_addoption(parser):
@ -105,16 +120,21 @@ def pytest_addoption(parser):
default=384, default=384,
help="Output dimensionality of the embedding model to use for testing. Default: 384", help="Output dimensionality of the embedding model to use for testing. Default: 384",
) )
parser.addoption(
"--record-responses",
action="store_true",
help="Record new API responses instead of using cached ones.",
)
parser.addoption( parser.addoption(
"--report", "--report",
help="Path where the test report should be written, e.g. --report=/path/to/report.md", help="Path where the test report should be written, e.g. --report=/path/to/report.md",
) )
available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys()))
suite_help = (
"Comma-separated integration test suites to narrow collection and prefill defaults. "
"Available: "
f"{available_suites}. "
"Explicit CLI flags (e.g., --text-model) override suite defaults. "
"Examples: --suite=responses or --suite=responses,vision."
)
parser.addoption("--suite", help=suite_help)
MODEL_SHORT_IDS = { MODEL_SHORT_IDS = {
"meta-llama/Llama-3.2-3B-Instruct": "3B", "meta-llama/Llama-3.2-3B-Instruct": "3B",
@ -197,3 +217,40 @@ def pytest_generate_tests(metafunc):
pytest_plugins = ["tests.integration.fixtures.common"] pytest_plugins = ["tests.integration.fixtures.common"]
def pytest_ignore_collect(path: str, config: pytest.Config) -> bool:
"""Skip collecting paths outside the selected suite roots for speed."""
suites_raw = config.getoption("--suite")
if not suites_raw:
return False
names = [p.strip() for p in str(suites_raw).split(",") if p.strip()]
roots: list[str] = []
for name in names:
suite_def = SUITE_DEFINITIONS.get(name)
if suite_def:
roots.extend(suite_def.get("roots", []))
if not roots:
return False
p = Path(str(path)).resolve()
# Only constrain within tests/integration to avoid ignoring unrelated tests
integration_root = (Path(str(config.rootpath)) / "tests" / "integration").resolve()
if not p.is_relative_to(integration_root):
return False
for r in roots:
rp = (Path(str(config.rootpath)) / r).resolve()
if rp.is_file():
# Allow the exact file and any ancestor directories so pytest can walk into it.
if p == rp:
return False
if p.is_dir() and rp.is_relative_to(p):
return False
else:
# Allow anything inside an allowed directory
if p.is_relative_to(rp):
return False
return True

View file

@ -5,6 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
import time
import pytest import pytest
from ..test_cases.test_case import TestCase from ..test_cases.test_case import TestCase
@ -35,6 +37,7 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id)
"remote::sambanova", "remote::sambanova",
"remote::tgi", "remote::tgi",
"remote::vertexai", "remote::vertexai",
"remote::gemini", # https://generativelanguage.googleapis.com/v1beta/openai/completions -> 404
): ):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.") pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.")
@ -56,6 +59,18 @@ def skip_if_model_doesnt_support_suffix(client_with_models, model_id):
pytest.skip(f"Provider {provider.provider_type} doesn't support suffix.") pytest.skip(f"Provider {provider.provider_type} doesn't support suffix.")
def skip_if_doesnt_support_n(client_with_models, model_id):
provider = provider_from_model(client_with_models, model_id)
if provider.provider_type in (
"remote::sambanova",
"remote::ollama",
# Error code: 400 - [{'error': {'code': 400, 'message': 'Only one candidate can be specified in the
# current model', 'status': 'INVALID_ARGUMENT'}}]
"remote::gemini",
):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.")
def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id): def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id):
provider = provider_from_model(client_with_models, model_id) provider = provider_from_model(client_with_models, model_id)
if provider.provider_type in ( if provider.provider_type in (
@ -260,10 +275,7 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex
) )
def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case): def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
skip_if_doesnt_support_n(client_with_models, text_model_id)
provider = provider_from_model(client_with_models, text_model_id)
if provider.provider_type == "remote::ollama":
pytest.skip(f"Model {text_model_id} hosted by {provider.provider_type} doesn't support n > 1.")
tc = TestCase(test_case) tc = TestCase(test_case)
question = tc["question"] question = tc["question"]
@ -323,8 +335,15 @@ def test_inference_store(compat_client, client_with_models, text_model_id, strea
response_id = response.id response_id = response.id
content = response.choices[0].message.content content = response.choices[0].message.content
tries = 0
while tries < 10:
responses = client.chat.completions.list(limit=1000) responses = client.chat.completions.list(limit=1000)
assert response_id in [r.id for r in responses.data] if response_id in [r.id for r in responses.data]:
break
else:
tries += 1
time.sleep(0.1)
assert tries < 10, f"Response {response_id} not found after 1 second"
retrieved_response = client.chat.completions.retrieve(response_id) retrieved_response = client.chat.completions.retrieve(response_id)
assert retrieved_response.id == response_id assert retrieved_response.id == response_id
@ -388,6 +407,18 @@ def test_inference_store_tool_calls(compat_client, client_with_models, text_mode
response_id = response.id response_id = response.id
content = response.choices[0].message.content content = response.choices[0].message.content
# wait for the response to be stored
tries = 0
while tries < 10:
responses = client.chat.completions.list(limit=1000)
if response_id in [r.id for r in responses.data]:
break
else:
tries += 1
time.sleep(0.1)
assert tries < 10, f"Response {response_id} not found after 1 second"
responses = client.chat.completions.list(limit=1000) responses = client.chat.completions.list(limit=1000)
assert response_id in [r.id for r in responses.data] assert response_id in [r.id for r in responses.data]

View file

@ -0,0 +1,42 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"prompt": "Say completions",
"max_tokens": 20
},
"endpoint": "/v1/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": {
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-271",
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": "You want me to respond with a completion, but you didn't specify what I should complete. Could"
}
],
"created": 1756846620,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": {
"completion_tokens": 20,
"prompt_tokens": 28,
"total_tokens": 48,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
"is_streaming": false
}
}

View file

Before

Width:  |  Height:  |  Size: 108 KiB

After

Width:  |  Height:  |  Size: 108 KiB

Before After
Before After

View file

Before

Width:  |  Height:  |  Size: 148 KiB

After

Width:  |  Height:  |  Size: 148 KiB

Before After
Before After

View file

Before

Width:  |  Height:  |  Size: 139 KiB

After

Width:  |  Height:  |  Size: 139 KiB

Before After
Before After

View file

@ -0,0 +1,53 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Central definition of integration test suites. You can use these suites by passing --suite=name to pytest.
# For example:
#
# ```bash
# pytest tests/integration/ --suite=vision
# ```
#
# Each suite can:
# - restrict collection to specific roots (dirs or files)
# - provide default CLI option values (e.g. text_model, embedding_model, etc.)
from pathlib import Path
this_dir = Path(__file__).parent
default_roots = [
str(p)
for p in this_dir.glob("*")
if p.is_dir()
and p.name not in ("__pycache__", "fixtures", "test_cases", "recordings", "responses", "post_training")
]
SUITE_DEFINITIONS: dict[str, dict] = {
"base": {
"description": "Base suite that includes most tests but runs them with a text Ollama model",
"roots": default_roots,
"defaults": {
"text_model": "ollama/llama3.2:3b-instruct-fp16",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
},
},
"responses": {
"description": "Suite that includes only the OpenAI Responses tests; needs a strong tool-calling model",
"roots": ["tests/integration/responses"],
"defaults": {
"text_model": "openai/gpt-4o",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
},
},
"vision": {
"description": "Suite that includes only the vision tests",
"roots": ["tests/integration/inference/test_vision_inference.py"],
"defaults": {
"vision_model": "ollama/llama3.2-vision:11b",
"embedding_model": "sentence-transformers/all-MiniLM-L6-v2",
},
},
}

View file

@ -17,10 +17,14 @@ def client_with_empty_registry(client_with_models):
client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id) client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id)
clear_registry() clear_registry()
try:
client_with_models.toolgroups.register(toolgroup_id="builtin::rag", provider_id="rag-runtime")
except Exception:
pass
yield client_with_models yield client_with_models
# you must clean after the last test if you were running tests against
# a stateful server instance
clear_registry() clear_registry()
@ -66,12 +70,13 @@ def assert_valid_text_response(response):
def test_vector_db_insert_inline_and_query( def test_vector_db_insert_inline_and_query(
client_with_empty_registry, sample_documents, embedding_model_id, embedding_dimension client_with_empty_registry, sample_documents, embedding_model_id, embedding_dimension
): ):
vector_db_id = "test_vector_db" vector_db_name = "test_vector_db"
client_with_empty_registry.vector_dbs.register( vector_db = client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id, vector_db_id=vector_db_name,
embedding_model=embedding_model_id, embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
) )
vector_db_id = vector_db.identifier
client_with_empty_registry.tool_runtime.rag_tool.insert( client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=sample_documents, documents=sample_documents,
@ -134,7 +139,11 @@ def test_vector_db_insert_from_url_and_query(
# list to check memory bank is successfully registered # list to check memory bank is successfully registered
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
assert vector_db_id in available_vector_dbs # VectorDB is being migrated to VectorStore, so the ID will be different
# Just check that at least one vector DB was registered
assert len(available_vector_dbs) > 0
# Use the actual registered vector_db_id for subsequent operations
actual_vector_db_id = available_vector_dbs[0]
urls = [ urls = [
"memory_optimizations.rst", "memory_optimizations.rst",
@ -153,13 +162,13 @@ def test_vector_db_insert_from_url_and_query(
client_with_empty_registry.tool_runtime.rag_tool.insert( client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=documents, documents=documents,
vector_db_id=vector_db_id, vector_db_id=actual_vector_db_id,
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
) )
# Query for the name of method # Query for the name of method
response1 = client_with_empty_registry.vector_io.query( response1 = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id, vector_db_id=actual_vector_db_id,
query="What's the name of the fine-tunning method used?", query="What's the name of the fine-tunning method used?",
) )
assert_valid_chunk_response(response1) assert_valid_chunk_response(response1)
@ -167,7 +176,7 @@ def test_vector_db_insert_from_url_and_query(
# Query for the name of model # Query for the name of model
response2 = client_with_empty_registry.vector_io.query( response2 = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id, vector_db_id=actual_vector_db_id,
query="Which Llama model is mentioned?", query="Which Llama model is mentioned?",
) )
assert_valid_chunk_response(response2) assert_valid_chunk_response(response2)
@ -187,7 +196,11 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
) )
available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
assert vector_db_id in available_vector_dbs # VectorDB is being migrated to VectorStore, so the ID will be different
# Just check that at least one vector DB was registered
assert len(available_vector_dbs) > 0
# Use the actual registered vector_db_id for subsequent operations
actual_vector_db_id = available_vector_dbs[0]
urls = [ urls = [
"memory_optimizations.rst", "memory_optimizations.rst",
@ -206,19 +219,19 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
client_with_empty_registry.tool_runtime.rag_tool.insert( client_with_empty_registry.tool_runtime.rag_tool.insert(
documents=documents, documents=documents,
vector_db_id=vector_db_id, vector_db_id=actual_vector_db_id,
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
) )
response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query( response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[vector_db_id], vector_db_ids=[actual_vector_db_id],
content="What is the name of the method used for fine-tuning?", content="What is the name of the method used for fine-tuning?",
) )
assert_valid_text_response(response_with_metadata) assert_valid_text_response(response_with_metadata)
assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content) assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content)
response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query( response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[vector_db_id], vector_db_ids=[actual_vector_db_id],
content="What is the name of the method used for fine-tuning?", content="What is the name of the method used for fine-tuning?",
query_config={ query_config={
"include_metadata_in_content": True, "include_metadata_in_content": True,
@ -230,7 +243,7 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i
with pytest.raises((ValueError, BadRequestError)): with pytest.raises((ValueError, BadRequestError)):
client_with_empty_registry.tool_runtime.rag_tool.query( client_with_empty_registry.tool_runtime.rag_tool.query(
vector_db_ids=[vector_db_id], vector_db_ids=[actual_vector_db_id],
content="What is the name of the method used for fine-tuning?", content="What is the name of the method used for fine-tuning?",
query_config={ query_config={
"chunk_template": "This should raise a ValueError because it is missing the proper template variables", "chunk_template": "This should raise a ValueError because it is missing the proper template variables",

View file

@ -47,34 +47,45 @@ def client_with_empty_registry(client_with_models):
def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension): def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension):
# Register a memory bank first vector_db_name = "test_vector_db"
vector_db_id = "test_vector_db" register_response = client_with_empty_registry.vector_dbs.register(
client_with_empty_registry.vector_dbs.register( vector_db_id=vector_db_name,
vector_db_id=vector_db_id,
embedding_model=embedding_model_id, embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
) )
actual_vector_db_id = register_response.identifier
# Retrieve the memory bank and validate its properties # Retrieve the memory bank and validate its properties
response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=vector_db_id) response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=actual_vector_db_id)
assert response is not None assert response is not None
assert response.identifier == vector_db_id assert response.identifier == actual_vector_db_id
assert response.embedding_model == embedding_model_id assert response.embedding_model == embedding_model_id
assert response.provider_resource_id == vector_db_id assert response.identifier.startswith("vs_")
def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension): def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension):
vector_db_id = "test_vector_db" vector_db_name = "test_vector_db"
client_with_empty_registry.vector_dbs.register( response = client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id, vector_db_id=vector_db_name,
embedding_model=embedding_model_id, embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
) )
vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] actual_vector_db_id = response.identifier
assert vector_dbs_after_register == [vector_db_id] assert actual_vector_db_id.startswith("vs_")
assert actual_vector_db_id != vector_db_name
client_with_empty_registry.vector_dbs.unregister(vector_db_id=vector_db_id) vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
assert vector_dbs_after_register == [actual_vector_db_id]
vector_stores = client_with_empty_registry.vector_stores.list()
assert len(vector_stores.data) == 1
vector_store = vector_stores.data[0]
assert vector_store.id == actual_vector_db_id
assert vector_store.name == vector_db_name
client_with_empty_registry.vector_dbs.unregister(vector_db_id=actual_vector_db_id)
vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()]
assert len(vector_dbs) == 0 assert len(vector_dbs) == 0
@ -91,20 +102,22 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe
], ],
) )
def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case): def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case):
vector_db_id = "test_vector_db" vector_db_name = "test_vector_db"
client_with_empty_registry.vector_dbs.register( register_response = client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id, vector_db_id=vector_db_name,
embedding_model=embedding_model_id, embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
) )
actual_vector_db_id = register_response.identifier
client_with_empty_registry.vector_io.insert( client_with_empty_registry.vector_io.insert(
vector_db_id=vector_db_id, vector_db_id=actual_vector_db_id,
chunks=sample_chunks, chunks=sample_chunks,
) )
response = client_with_empty_registry.vector_io.query( response = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id, vector_db_id=actual_vector_db_id,
query="What is the capital of France?", query="What is the capital of France?",
) )
assert response is not None assert response is not None
@ -113,7 +126,7 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding
query, expected_doc_id = test_case query, expected_doc_id = test_case
response = client_with_empty_registry.vector_io.query( response = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id, vector_db_id=actual_vector_db_id,
query=query, query=query,
) )
assert response is not None assert response is not None
@ -128,13 +141,15 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
"remote::qdrant": {"score_threshold": -1.0}, "remote::qdrant": {"score_threshold": -1.0},
"inline::qdrant": {"score_threshold": -1.0}, "inline::qdrant": {"score_threshold": -1.0},
} }
vector_db_id = "test_precomputed_embeddings_db" vector_db_name = "test_precomputed_embeddings_db"
client_with_empty_registry.vector_dbs.register( register_response = client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id, vector_db_id=vector_db_name,
embedding_model=embedding_model_id, embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
) )
actual_vector_db_id = register_response.identifier
chunks_with_embeddings = [ chunks_with_embeddings = [
Chunk( Chunk(
content="This is a test chunk with precomputed embedding.", content="This is a test chunk with precomputed embedding.",
@ -144,13 +159,13 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e
] ]
client_with_empty_registry.vector_io.insert( client_with_empty_registry.vector_io.insert(
vector_db_id=vector_db_id, vector_db_id=actual_vector_db_id,
chunks=chunks_with_embeddings, chunks=chunks_with_embeddings,
) )
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0] provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
response = client_with_empty_registry.vector_io.query( response = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id, vector_db_id=actual_vector_db_id,
query="precomputed embedding test", query="precomputed embedding test",
params=vector_io_provider_params_dict.get(provider, None), params=vector_io_provider_params_dict.get(provider, None),
) )
@ -173,13 +188,15 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
"remote::qdrant": {"score_threshold": 0.0}, "remote::qdrant": {"score_threshold": 0.0},
"inline::qdrant": {"score_threshold": 0.0}, "inline::qdrant": {"score_threshold": 0.0},
} }
vector_db_id = "test_precomputed_embeddings_db" vector_db_name = "test_precomputed_embeddings_db"
client_with_empty_registry.vector_dbs.register( register_response = client_with_empty_registry.vector_dbs.register(
vector_db_id=vector_db_id, vector_db_id=vector_db_name,
embedding_model=embedding_model_id, embedding_model=embedding_model_id,
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
) )
actual_vector_db_id = register_response.identifier
chunks_with_embeddings = [ chunks_with_embeddings = [
Chunk( Chunk(
content="duplicate", content="duplicate",
@ -189,13 +206,13 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb(
] ]
client_with_empty_registry.vector_io.insert( client_with_empty_registry.vector_io.insert(
vector_db_id=vector_db_id, vector_db_id=actual_vector_db_id,
chunks=chunks_with_embeddings, chunks=chunks_with_embeddings,
) )
provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0] provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0]
response = client_with_empty_registry.vector_io.query( response = client_with_empty_registry.vector_io.query(
vector_db_id=vector_db_id, vector_db_id=actual_vector_db_id,
query="duplicate", query="duplicate",
params=vector_io_provider_params_dict.get(provider, None), params=vector_io_provider_params_dict.get(provider, None),
) )

View file

@ -146,6 +146,20 @@ class VectorDBImpl(Impl):
async def unregister_vector_db(self, vector_db_id: str): async def unregister_vector_db(self, vector_db_id: str):
return vector_db_id return vector_db_id
async def openai_create_vector_store(self, **kwargs):
import time
import uuid
from llama_stack.apis.vector_io.vector_io import VectorStoreFileCounts, VectorStoreObject
vector_store_id = kwargs.get("provider_vector_db_id") or f"vs_{uuid.uuid4()}"
return VectorStoreObject(
id=vector_store_id,
name=kwargs.get("name", vector_store_id),
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
async def test_models_routing_table(cached_disk_dist_registry): async def test_models_routing_table(cached_disk_dist_registry):
table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
@ -247,17 +261,21 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
) )
# Register multiple vector databases and verify listing # Register multiple vector databases and verify listing
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model") vdb1 = await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model")
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model") vdb2 = await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model")
vector_dbs = await table.list_vector_dbs() vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2 assert len(vector_dbs.data) == 2
vector_db_ids = {v.identifier for v in vector_dbs.data} vector_db_ids = {v.identifier for v in vector_dbs.data}
assert "test-vectordb" in vector_db_ids assert vdb1.identifier in vector_db_ids
assert "test-vectordb-2" in vector_db_ids assert vdb2.identifier in vector_db_ids
await table.unregister_vector_db(vector_db_id="test-vectordb") # Verify they have UUID-based identifiers
await table.unregister_vector_db(vector_db_id="test-vectordb-2") assert vdb1.identifier.startswith("vs_")
assert vdb2.identifier.startswith("vs_")
await table.unregister_vector_db(vector_db_id=vdb1.identifier)
await table.unregister_vector_db(vector_db_id=vdb2.identifier)
vector_dbs = await table.list_vector_dbs() vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0 assert len(vector_dbs.data) == 0

View file

@ -7,6 +7,7 @@
# Unit tests for the routing tables vector_dbs # Unit tests for the routing tables vector_dbs
import time import time
import uuid
from unittest.mock import AsyncMock from unittest.mock import AsyncMock
import pytest import pytest
@ -34,6 +35,7 @@ from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceI
class VectorDBImpl(Impl): class VectorDBImpl(Impl):
def __init__(self): def __init__(self):
super().__init__(Api.vector_io) super().__init__(Api.vector_io)
self.vector_stores = {}
async def register_vector_db(self, vector_db: VectorDB): async def register_vector_db(self, vector_db: VectorDB):
return vector_db return vector_db
@ -114,8 +116,35 @@ class VectorDBImpl(Impl):
async def openai_delete_vector_store_file(self, vector_store_id, file_id): async def openai_delete_vector_store_file(self, vector_store_id, file_id):
return VectorStoreFileDeleteResponse(id=file_id, deleted=True) return VectorStoreFileDeleteResponse(id=file_id, deleted=True)
async def openai_create_vector_store(
self,
name=None,
embedding_model=None,
embedding_dimension=None,
provider_id=None,
provider_vector_db_id=None,
**kwargs,
):
vector_store_id = provider_vector_db_id or f"vs_{uuid.uuid4()}"
vector_store = VectorStoreObject(
id=vector_store_id,
name=name or vector_store_id,
created_at=int(time.time()),
file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0),
)
self.vector_stores[vector_store_id] = vector_store
return vector_store
async def openai_list_vector_stores(self, **kwargs):
from llama_stack.apis.vector_io.vector_io import VectorStoreListResponse
return VectorStoreListResponse(
data=list(self.vector_stores.values()), has_more=False, first_id=None, last_id=None
)
async def test_vectordbs_routing_table(cached_disk_dist_registry): async def test_vectordbs_routing_table(cached_disk_dist_registry):
n = 10
table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {})
await table.initialize() await table.initialize()
@ -129,22 +158,98 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry):
) )
# Register multiple vector databases and verify listing # Register multiple vector databases and verify listing
await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model") vdb_dict = {}
await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model") for i in range(n):
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs() vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 2 assert len(vector_dbs.data) == len(vdb_dict)
vector_db_ids = {v.identifier for v in vector_dbs.data} vector_db_ids = {v.identifier for v in vector_dbs.data}
assert "test-vectordb" in vector_db_ids for k in vdb_dict:
assert "test-vectordb-2" in vector_db_ids assert vdb_dict[k].identifier in vector_db_ids
for k in vdb_dict:
await table.unregister_vector_db(vector_db_id="test-vectordb") await table.unregister_vector_db(vector_db_id=vdb_dict[k].identifier)
await table.unregister_vector_db(vector_db_id="test-vectordb-2")
vector_dbs = await table.list_vector_dbs() vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 0 assert len(vector_dbs.data) == 0
async def test_vector_db_and_vector_store_id_mapping(cached_disk_dist_registry):
n = 10
impl = VectorDBImpl()
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
vdb_dict = {}
for i in range(n):
vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model")
vector_dbs = await table.list_vector_dbs()
vector_db_ids = {v.identifier for v in vector_dbs.data}
vector_stores = await impl.openai_list_vector_stores()
vector_store_ids = {v.id for v in vector_stores.data}
assert vector_db_ids == vector_store_ids, (
f"Vector DB IDs {vector_db_ids} don't match vector store IDs {vector_store_ids}"
)
for vector_store in vector_stores.data:
vector_db = await table.get_vector_db(vector_store.id)
assert vector_store.name == vector_db.vector_db_name, (
f"Vector store name {vector_store.name} doesn't match vector store ID {vector_store.id}"
)
for vector_db_id in vector_db_ids:
await table.unregister_vector_db(vector_db_id)
assert len((await table.list_vector_dbs()).data) == 0
async def test_vector_db_id_becomes_vector_store_name(cached_disk_dist_registry):
impl = VectorDBImpl()
table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {})
await table.initialize()
m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {})
await m_table.initialize()
await m_table.register_model(
model_id="test-model",
provider_id="test_provider",
metadata={"embedding_dimension": 128},
model_type=ModelType.embedding,
)
user_provided_id = "my-custom-vector-db"
await table.register_vector_db(vector_db_id=user_provided_id, embedding_model="test-model")
vector_stores = await impl.openai_list_vector_stores()
assert len(vector_stores.data) == 1
vector_store = vector_stores.data[0]
assert vector_store.name == user_provided_id
assert vector_store.id.startswith("vs_")
assert vector_store.id != user_provided_id
vector_dbs = await table.list_vector_dbs()
assert len(vector_dbs.data) == 1
assert vector_dbs.data[0].identifier == vector_store.id
await table.unregister_vector_db(vector_store.id)
async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry): async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry):
impl = VectorDBImpl() impl = VectorDBImpl()
impl.openai_retrieve_vector_store = AsyncMock(return_value="OK") impl.openai_retrieve_vector_store = AsyncMock(return_value="OK")
@ -164,7 +269,8 @@ async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registr
authorized_user = User(principal="alice", attributes={"roles": [authorized_team]}) authorized_user = User(principal="alice", attributes={"roles": [authorized_team]})
with request_provider_data_context({}, authorized_user): with request_provider_data_context({}, authorized_user):
_ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model") registered_vdb = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model")
authorized_table = registered_vdb.identifier # Use the actual generated ID
# Authorized reader # Authorized reader
with request_provider_data_context({}, authorized_user): with request_provider_data_context({}, authorized_user):
@ -227,7 +333,8 @@ async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_regis
) )
with request_provider_data_context({}, admin_user): with request_provider_data_context({}, admin_user):
await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model") registered_vdb = await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model")
vector_db_id = registered_vdb.identifier # Use the actual generated ID
read_methods = [ read_methods = [
(table.openai_retrieve_vector_store, (vector_db_id,), {}), (table.openai_retrieve_vector_store, (vector_db_id,), {}),

View file

@ -46,7 +46,8 @@ The tests are categorized and outlined below, keep this updated:
* test_validate_input_url_mismatch (negative) * test_validate_input_url_mismatch (negative)
* test_validate_input_multiple_errors_per_request (negative) * test_validate_input_multiple_errors_per_request (negative)
* test_validate_input_invalid_request_format (negative) * test_validate_input_invalid_request_format (negative)
* test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation) * test_validate_input_missing_parameters_chat_completions (parametrized negative - custom_id, method, url, body, model, messages missing validation for chat/completions)
* test_validate_input_missing_parameters_completions (parametrized negative - custom_id, method, url, body, model, prompt missing validation for completions)
* test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation) * test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation)
The tests use temporary SQLite databases for isolation and mock external The tests use temporary SQLite databases for isolation and mock external
@ -213,7 +214,6 @@ class TestReferenceBatchesImpl:
"endpoint", "endpoint",
[ [
"/v1/embeddings", "/v1/embeddings",
"/v1/completions",
"/v1/invalid/endpoint", "/v1/invalid/endpoint",
"", "",
], ],
@ -499,8 +499,10 @@ class TestReferenceBatchesImpl:
("messages", "body.messages", "invalid_request", "Messages parameter is required"), ("messages", "body.messages", "invalid_request", "Messages parameter is required"),
], ],
) )
async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message): async def test_validate_input_missing_parameters_chat_completions(
"""Test _validate_input when file contains request with missing required parameters.""" self, provider, param_name, param_path, error_code, error_message
):
"""Test _validate_input when file contains request with missing required parameters for chat completions."""
provider.files_api.openai_retrieve_file = AsyncMock() provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock() mock_response = MagicMock()
@ -541,6 +543,61 @@ class TestReferenceBatchesImpl:
assert errors[0].message == error_message assert errors[0].message == error_message
assert errors[0].param == param_path assert errors[0].param == param_path
@pytest.mark.parametrize(
"param_name,param_path,error_code,error_message",
[
("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"),
("method", "method", "missing_required_parameter", "Missing required parameter: method"),
("url", "url", "missing_required_parameter", "Missing required parameter: url"),
("body", "body", "missing_required_parameter", "Missing required parameter: body"),
("model", "body.model", "invalid_request", "Model parameter is required"),
("prompt", "body.prompt", "invalid_request", "Prompt parameter is required"),
],
)
async def test_validate_input_missing_parameters_completions(
self, provider, param_name, param_path, error_code, error_message
):
"""Test _validate_input when file contains request with missing required parameters for text completions."""
provider.files_api.openai_retrieve_file = AsyncMock()
mock_response = MagicMock()
base_request = {
"custom_id": "req-1",
"method": "POST",
"url": "/v1/completions",
"body": {"model": "test-model", "prompt": "Hello"},
}
# Remove the specific parameter being tested
if "." in param_path:
top_level, nested_param = param_path.split(".", 1)
del base_request[top_level][nested_param]
else:
del base_request[param_name]
mock_response.body = json.dumps(base_request).encode()
provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response)
batch = BatchObject(
id="batch_test",
object="batch",
endpoint="/v1/completions",
input_file_id=f"missing_{param_name}_file",
completion_window="24h",
status="validating",
created_at=1234567890,
)
errors, requests = await provider._validate_input(batch)
assert len(errors) == 1
assert len(requests) == 0
assert errors[0].code == error_code
assert errors[0].line == 1
assert errors[0].message == error_message
assert errors[0].param == param_path
async def test_validate_input_url_mismatch(self, provider): async def test_validate_input_url_mismatch(self, provider):
"""Test _validate_input when file contains request with URL that doesn't match batch endpoint.""" """Test _validate_input when file contains request with URL that doesn't match batch endpoint."""
provider.files_api.openai_retrieve_file = AsyncMock() provider.files_api.openai_retrieve_file = AsyncMock()

View file

@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from unittest.mock import patch
from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig
class TestBedrockBaseConfig:
def test_defaults_work_without_env_vars(self):
with patch.dict(os.environ, {}, clear=True):
config = BedrockBaseConfig()
# Basic creds should be None
assert config.aws_access_key_id is None
assert config.aws_secret_access_key is None
assert config.region_name is None
# Timeouts get defaults
assert config.connect_timeout == 60.0
assert config.read_timeout == 60.0
assert config.session_ttl == 3600
def test_env_vars_get_picked_up(self):
env_vars = {
"AWS_ACCESS_KEY_ID": "AKIATEST123",
"AWS_SECRET_ACCESS_KEY": "secret123",
"AWS_DEFAULT_REGION": "us-west-2",
"AWS_MAX_ATTEMPTS": "5",
"AWS_RETRY_MODE": "adaptive",
"AWS_CONNECT_TIMEOUT": "30",
}
with patch.dict(os.environ, env_vars, clear=True):
config = BedrockBaseConfig()
assert config.aws_access_key_id == "AKIATEST123"
assert config.aws_secret_access_key == "secret123"
assert config.region_name == "us-west-2"
assert config.total_max_attempts == 5
assert config.retry_mode == "adaptive"
assert config.connect_timeout == 30.0
def test_partial_env_setup(self):
# Just setting one timeout var
with patch.dict(os.environ, {"AWS_CONNECT_TIMEOUT": "120"}, clear=True):
config = BedrockBaseConfig()
assert config.connect_timeout == 120.0
assert config.read_timeout == 60.0 # still default
assert config.aws_access_key_id is None
def test_bad_max_attempts_breaks(self):
with patch.dict(os.environ, {"AWS_MAX_ATTEMPTS": "not_a_number"}, clear=True):
try:
BedrockBaseConfig()
raise AssertionError("Should have failed on bad int conversion")
except ValueError:
pass # expected

View file

@ -19,12 +19,16 @@ from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRunti
class TestRagQuery: class TestRagQuery:
async def test_query_raises_on_empty_vector_db_ids(self): async def test_query_raises_on_empty_vector_db_ids(self):
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock()) rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
)
with pytest.raises(ValueError): with pytest.raises(ValueError):
await rag_tool.query(content=MagicMock(), vector_db_ids=[]) await rag_tool.query(content=MagicMock(), vector_db_ids=[])
async def test_query_chunk_metadata_handling(self): async def test_query_chunk_metadata_handling(self):
rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock()) rag_tool = MemoryToolRuntimeImpl(
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
)
content = "test query content" content = "test query content"
vector_db_ids = ["db1"] vector_db_ids = ["db1"]

View file

@ -113,6 +113,15 @@ class TestTranslateException:
assert result.status_code == 504 assert result.status_code == 504
assert result.detail == "Operation timed out: " assert result.detail == "Operation timed out: "
def test_translate_connection_error(self):
"""Test that ConnectionError is translated to 502 HTTP status."""
exc = ConnectionError("Failed to connect to MCP server at http://localhost:9999/sse: Connection refused")
result = translate_exception(exc)
assert isinstance(result, HTTPException)
assert result.status_code == 502
assert result.detail == "Failed to connect to MCP server at http://localhost:9999/sse: Connection refused"
def test_translate_not_implemented_error(self): def test_translate_not_implemented_error(self):
"""Test that NotImplementedError is translated to 501 HTTP status.""" """Test that NotImplementedError is translated to 501 HTTP status."""
exc = NotImplementedError("Not implemented") exc = NotImplementedError("Not implemented")