diff --git a/.github/actions/run-and-record-tests/action.yml b/.github/actions/run-and-record-tests/action.yml index 60550cfdc..7f028b104 100644 --- a/.github/actions/run-and-record-tests/action.yml +++ b/.github/actions/run-and-record-tests/action.yml @@ -2,13 +2,6 @@ name: 'Run and Record Tests' description: 'Run integration tests and handle recording/artifact upload' inputs: - test-subdirs: - description: 'Comma-separated list of test subdirectories to run' - required: true - test-pattern: - description: 'Regex pattern to pass to pytest -k' - required: false - default: '' stack-config: description: 'Stack configuration to use' required: true @@ -18,10 +11,18 @@ inputs: inference-mode: description: 'Inference mode (record or replay)' required: true - run-vision-tests: - description: 'Whether to run vision tests' + test-suite: + description: 'Test suite to use: base, responses, vision, etc.' required: false - default: 'false' + default: '' + test-subdirs: + description: 'Comma-separated list of test subdirectories to run; overrides test-suite' + required: false + default: '' + test-pattern: + description: 'Regex pattern to pass to pytest -k' + required: false + default: '' runs: using: 'composite' @@ -42,7 +43,7 @@ runs: --test-subdirs '${{ inputs.test-subdirs }}' \ --test-pattern '${{ inputs.test-pattern }}' \ --inference-mode '${{ inputs.inference-mode }}' \ - ${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }} \ + --test-suite '${{ inputs.test-suite }}' \ | tee pytest-${{ inputs.inference-mode }}.log @@ -57,12 +58,7 @@ runs: echo "New recordings detected, committing and pushing" git add tests/integration/recordings/ - if [ "${{ inputs.run-vision-tests }}" == "true" ]; then - git commit -m "Recordings update from CI (vision)" - else - git commit -m "Recordings update from CI" - fi - + git commit -m "Recordings update from CI (test-suite: ${{ inputs.test-suite }})" git fetch origin ${{ github.ref_name }} git rebase origin/${{ github.ref_name }} echo "Rebased successfully" diff --git a/.github/actions/setup-ollama/action.yml b/.github/actions/setup-ollama/action.yml index e57876cb0..dc2f87e8c 100644 --- a/.github/actions/setup-ollama/action.yml +++ b/.github/actions/setup-ollama/action.yml @@ -1,17 +1,17 @@ name: Setup Ollama description: Start Ollama inputs: - run-vision-tests: - description: 'Run vision tests: "true" or "false"' + test-suite: + description: 'Test suite to use: base, responses, vision, etc.' required: false - default: 'false' + default: '' runs: using: "composite" steps: - name: Start Ollama shell: bash run: | - if [ "${{ inputs.run-vision-tests }}" == "true" ]; then + if [ "${{ inputs.test-suite }}" == "vision" ]; then image="ollama-with-vision-model" else image="ollama-with-models" diff --git a/.github/actions/setup-test-environment/action.yml b/.github/actions/setup-test-environment/action.yml index d830e3d13..3be76f009 100644 --- a/.github/actions/setup-test-environment/action.yml +++ b/.github/actions/setup-test-environment/action.yml @@ -12,10 +12,10 @@ inputs: description: 'Provider to setup (ollama or vllm)' required: true default: 'ollama' - run-vision-tests: - description: 'Whether to setup provider for vision tests' + test-suite: + description: 'Test suite to use: base, responses, vision, etc.' required: false - default: 'false' + default: '' inference-mode: description: 'Inference mode (record or replay)' required: true @@ -33,7 +33,7 @@ runs: if: ${{ inputs.provider == 'ollama' && inputs.inference-mode == 'record' }} uses: ./.github/actions/setup-ollama with: - run-vision-tests: ${{ inputs.run-vision-tests }} + test-suite: ${{ inputs.test-suite }} - name: Setup vllm if: ${{ inputs.provider == 'vllm' && inputs.inference-mode == 'record' }} diff --git a/.github/workflows/README.md b/.github/workflows/README.md index 8344d12a4..2e0df58b8 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -8,7 +8,7 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl | Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script | | Integration Auth Tests | [integration-auth-tests.yml](integration-auth-tests.yml) | Run the integration test suite with Kubernetes authentication | | SqlStore Integration Tests | [integration-sql-store-tests.yml](integration-sql-store-tests.yml) | Run the integration test suite with SqlStore | -| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suite from tests/integration in replay mode | +| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suites from tests/integration in replay mode | | Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers | | Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks | | Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build | diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 57e582b20..bb53eea2f 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -1,6 +1,6 @@ name: Integration Tests (Replay) -run-name: Run the integration test suite from tests/integration in replay mode +run-name: Run the integration test suites from tests/integration in replay mode on: push: @@ -32,14 +32,6 @@ on: description: 'Test against a specific provider' type: string default: 'ollama' - test-subdirs: - description: 'Comma-separated list of test subdirectories to run' - type: string - default: '' - test-pattern: - description: 'Regex pattern to pass to pytest -k' - type: string - default: '' concurrency: # Skip concurrency for pushes to main - each commit should be tested independently @@ -50,7 +42,7 @@ jobs: run-replay-mode-tests: runs-on: ubuntu-latest - name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }} + name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, {4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.test-suite) }} strategy: fail-fast: false @@ -61,7 +53,7 @@ jobs: # Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12 python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }} client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }} - run-vision-tests: [true, false] + test-suite: [base, vision] steps: - name: Checkout repository @@ -73,15 +65,13 @@ jobs: python-version: ${{ matrix.python-version }} client-version: ${{ matrix.client-version }} provider: ${{ matrix.provider }} - run-vision-tests: ${{ matrix.run-vision-tests }} + test-suite: ${{ matrix.test-suite }} inference-mode: 'replay' - name: Run tests uses: ./.github/actions/run-and-record-tests with: - test-subdirs: ${{ inputs.test-subdirs }} - test-pattern: ${{ inputs.test-pattern }} stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }} provider: ${{ matrix.provider }} inference-mode: 'replay' - run-vision-tests: ${{ matrix.run-vision-tests }} + test-suite: ${{ matrix.test-suite }} diff --git a/.github/workflows/record-integration-tests.yml b/.github/workflows/record-integration-tests.yml index d4f5586e2..01797a54b 100644 --- a/.github/workflows/record-integration-tests.yml +++ b/.github/workflows/record-integration-tests.yml @@ -10,18 +10,18 @@ run-name: Run the integration test suite from tests/integration on: workflow_dispatch: inputs: - test-subdirs: - description: 'Comma-separated list of test subdirectories to run' - type: string - default: '' test-provider: description: 'Test against a specific provider' type: string default: 'ollama' - run-vision-tests: - description: 'Whether to run vision tests' - type: boolean - default: false + test-suite: + description: 'Test suite to use: base, responses, vision, etc.' + type: string + default: '' + test-subdirs: + description: 'Comma-separated list of test subdirectories to run; overrides test-suite' + type: string + default: '' test-pattern: description: 'Regex pattern to pass to pytest -k' type: string @@ -38,11 +38,11 @@ jobs: - name: Echo workflow inputs run: | echo "::group::Workflow Inputs" - echo "test-subdirs: ${{ inputs.test-subdirs }}" - echo "test-provider: ${{ inputs.test-provider }}" - echo "run-vision-tests: ${{ inputs.run-vision-tests }}" - echo "test-pattern: ${{ inputs.test-pattern }}" echo "branch: ${{ github.ref_name }}" + echo "test-provider: ${{ inputs.test-provider }}" + echo "test-suite: ${{ inputs.test-suite }}" + echo "test-subdirs: ${{ inputs.test-subdirs }}" + echo "test-pattern: ${{ inputs.test-pattern }}" echo "::endgroup::" - name: Checkout repository @@ -56,15 +56,15 @@ jobs: python-version: "3.12" # Use single Python version for recording client-version: "latest" provider: ${{ inputs.test-provider || 'ollama' }} - run-vision-tests: ${{ inputs.run-vision-tests }} + test-suite: ${{ inputs.test-suite }} inference-mode: 'record' - name: Run and record tests uses: ./.github/actions/run-and-record-tests with: - test-pattern: ${{ inputs.test-pattern }} - test-subdirs: ${{ inputs.test-subdirs }} stack-config: 'server:ci-tests' # recording must be done with server since more tests are run provider: ${{ inputs.test-provider || 'ollama' }} inference-mode: 'record' - run-vision-tests: ${{ inputs.run-vision-tests }} + test-suite: ${{ inputs.test-suite }} + test-subdirs: ${{ inputs.test-subdirs }} + test-pattern: ${{ inputs.test-pattern }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 514fe6d2e..b7880a9fc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -86,7 +86,7 @@ repos: language: python pass_filenames: false require_serial: true - files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$ + files: ^llama_stack/distributions/.*$|^llama_stack/providers/.*/inference/.*/models\.py$ - id: provider-codegen name: Provider Codegen additional_dependencies: diff --git a/docs/source/distributions/k8s-benchmark/stack_run_config.yaml b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml index ceb1ba2d9..5a810639e 100644 --- a/docs/source/distributions/k8s-benchmark/stack_run_config.yaml +++ b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml @@ -3,6 +3,7 @@ image_name: kubernetes-benchmark-demo apis: - agents - inference +- safety - telemetry - tool_runtime - vector_io @@ -30,6 +31,11 @@ providers: db: ${env.POSTGRES_DB:=llamastack} user: ${env.POSTGRES_USER:=llamastack} password: ${env.POSTGRES_PASSWORD:=llamastack} + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] agents: - provider_id: meta-reference provider_type: inline::meta-reference @@ -95,6 +101,8 @@ models: - model_id: ${env.INFERENCE_MODEL} provider_id: vllm-inference model_type: llm +shields: +- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B} vector_dbs: [] datasets: [] scoring_fns: [] diff --git a/docs/source/getting_started/demo_script.py b/docs/source/getting_started/demo_script.py index 777fc78c2..2ea67739f 100644 --- a/docs/source/getting_started/demo_script.py +++ b/docs/source/getting_started/demo_script.py @@ -18,12 +18,13 @@ embedding_model_id = ( ).identifier embedding_dimension = em.metadata["embedding_dimension"] -_ = client.vector_dbs.register( +vector_db = client.vector_dbs.register( vector_db_id=vector_db_id, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, provider_id="faiss", ) +vector_db_id = vector_db.identifier source = "https://www.paulgraham.com/greatwork.html" print("rag_tool> Ingesting document:", source) document = RAGDocument( @@ -35,7 +36,7 @@ document = RAGDocument( client.tool_runtime.rag_tool.insert( documents=[document], vector_db_id=vector_db_id, - chunk_size_in_tokens=50, + chunk_size_in_tokens=100, ) agent = Agent( client, diff --git a/docs/source/providers/inference/remote_bedrock.md b/docs/source/providers/inference/remote_bedrock.md index 1454c54c2..216dd4adb 100644 --- a/docs/source/providers/inference/remote_bedrock.md +++ b/docs/source/providers/inference/remote_bedrock.md @@ -15,8 +15,8 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man | `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE | | `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS | | `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE | -| `connect_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. | -| `read_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. | +| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. | +| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. | | `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). | ## Sample Configuration diff --git a/docs/source/providers/safety/remote_bedrock.md b/docs/source/providers/safety/remote_bedrock.md index 3c1d6bcb0..99d77dd72 100644 --- a/docs/source/providers/safety/remote_bedrock.md +++ b/docs/source/providers/safety/remote_bedrock.md @@ -15,8 +15,8 @@ AWS Bedrock safety provider for content moderation using AWS's safety services. | `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE | | `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS | | `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE | -| `connect_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. | -| `read_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. | +| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. | +| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. | | `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). | ## Sample Configuration diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 8dcad85e3..045093fe0 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -527,7 +527,7 @@ class InferenceRouter(Inference): # Store the response with the ID that will be returned to the client if self.store: - await self.store.store_chat_completion(response, messages) + asyncio.create_task(self.store.store_chat_completion(response, messages)) if self.telemetry: metrics = self._construct_metrics( @@ -855,4 +855,4 @@ class InferenceRouter(Inference): object="chat.completion", ) logger.debug(f"InferenceRouter.completion_response: {final_response}") - await self.store.store_chat_completion(final_response, messages) + asyncio.create_task(self.store.store_chat_completion(final_response, messages)) diff --git a/llama_stack/core/routing_tables/vector_dbs.py b/llama_stack/core/routing_tables/vector_dbs.py index 00f71b4fe..497894064 100644 --- a/llama_stack/core/routing_tables/vector_dbs.py +++ b/llama_stack/core/routing_tables/vector_dbs.py @@ -52,7 +52,6 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): provider_vector_db_id: str | None = None, vector_db_name: str | None = None, ) -> VectorDB: - provider_vector_db_id = provider_vector_db_id or vector_db_id if provider_id is None: if len(self.impls_by_provider_id) > 0: provider_id = list(self.impls_by_provider_id.keys())[0] @@ -69,14 +68,33 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding) if "embedding_dimension" not in model.metadata: raise ValueError(f"Model {embedding_model} does not have an embedding dimension") + + provider = self.impls_by_provider_id[provider_id] + logger.warning( + "VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly." + ) + vector_store = await provider.openai_create_vector_store( + name=vector_db_name or vector_db_id, + embedding_model=embedding_model, + embedding_dimension=model.metadata["embedding_dimension"], + provider_id=provider_id, + provider_vector_db_id=provider_vector_db_id, + ) + + vector_store_id = vector_store.id + actual_provider_vector_db_id = provider_vector_db_id or vector_store_id + logger.warning( + f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name" + ) + vector_db_data = { - "identifier": vector_db_id, + "identifier": vector_store_id, "type": ResourceType.vector_db.value, "provider_id": provider_id, - "provider_resource_id": provider_vector_db_id, + "provider_resource_id": actual_provider_vector_db_id, "embedding_model": embedding_model, "embedding_dimension": model.metadata["embedding_dimension"], - "vector_db_name": vector_db_name, + "vector_db_name": vector_store.name, } vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data) await self.register_object(vector_db) diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index d6dfc3435..288bf46e1 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -132,15 +132,17 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro }, ) elif isinstance(exc, ConflictError): - return HTTPException(status_code=409, detail=str(exc)) + return HTTPException(status_code=httpx.codes.CONFLICT, detail=str(exc)) elif isinstance(exc, ResourceNotFoundError): - return HTTPException(status_code=404, detail=str(exc)) + return HTTPException(status_code=httpx.codes.NOT_FOUND, detail=str(exc)) elif isinstance(exc, ValueError): return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}") elif isinstance(exc, BadRequestError): return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc)) elif isinstance(exc, PermissionError | AccessDeniedError): return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}") + elif isinstance(exc, ConnectionError | httpx.ConnectError): + return HTTPException(status_code=httpx.codes.BAD_GATEWAY, detail=str(exc)) elif isinstance(exc, asyncio.TimeoutError | TimeoutError): return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}") elif isinstance(exc, NotImplementedError): diff --git a/llama_stack/distributions/ci-tests/ci_tests.py b/llama_stack/distributions/ci-tests/ci_tests.py index 8fb61faca..ab102f5f3 100644 --- a/llama_stack/distributions/ci-tests/ci_tests.py +++ b/llama_stack/distributions/ci-tests/ci_tests.py @@ -11,9 +11,7 @@ from ..starter.starter import get_distribution_template as get_starter_distribut def get_distribution_template() -> DistributionTemplate: - template = get_starter_distribution_template() - name = "ci-tests" - template.name = name + template = get_starter_distribution_template(name="ci-tests") template.description = "CI tests for Llama Stack" return template diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml index 7523df581..26a677c7a 100644 --- a/llama_stack/distributions/ci-tests/run.yaml +++ b/llama_stack/distributions/ci-tests/run.yaml @@ -89,28 +89,28 @@ providers: config: kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/faiss_store.db - provider_id: sqlite-vec provider_type: inline::sqlite-vec config: - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec.db kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec_registry.db - provider_id: ${env.MILVUS_URL:+milvus} provider_type: inline::milvus config: - db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db + db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/ci-tests}/milvus.db kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/milvus_registry.db - provider_id: ${env.CHROMADB_URL:+chromadb} provider_type: remote::chromadb config: url: ${env.CHROMADB_URL:=} kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests/}/chroma_remote_registry.db - provider_id: ${env.PGVECTOR_DB:+pgvector} provider_type: remote::pgvector config: @@ -121,15 +121,15 @@ providers: password: ${env.PGVECTOR_PASSWORD:=} kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/pgvector_registry.db files: - provider_id: meta-reference-files provider_type: inline::localfs config: - storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files} + storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/ci-tests/files} metadata_store: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/files_metadata.db safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/llama_stack/distributions/starter-gpu/run.yaml b/llama_stack/distributions/starter-gpu/run.yaml index 8aed61519..5d9dfcb27 100644 --- a/llama_stack/distributions/starter-gpu/run.yaml +++ b/llama_stack/distributions/starter-gpu/run.yaml @@ -89,28 +89,28 @@ providers: config: kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/faiss_store.db - provider_id: sqlite-vec provider_type: inline::sqlite-vec config: - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/sqlite_vec.db kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/sqlite_vec_registry.db - provider_id: ${env.MILVUS_URL:+milvus} provider_type: inline::milvus config: - db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db + db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter-gpu}/milvus.db kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/milvus_registry.db - provider_id: ${env.CHROMADB_URL:+chromadb} provider_type: remote::chromadb config: url: ${env.CHROMADB_URL:=} kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu/}/chroma_remote_registry.db - provider_id: ${env.PGVECTOR_DB:+pgvector} provider_type: remote::pgvector config: @@ -121,15 +121,15 @@ providers: password: ${env.PGVECTOR_PASSWORD:=} kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/pgvector_registry.db files: - provider_id: meta-reference-files provider_type: inline::localfs config: - storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files} + storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter-gpu/files} metadata_store: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/files_metadata.db safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/llama_stack/distributions/starter-gpu/starter_gpu.py b/llama_stack/distributions/starter-gpu/starter_gpu.py index 245334749..e7efcb283 100644 --- a/llama_stack/distributions/starter-gpu/starter_gpu.py +++ b/llama_stack/distributions/starter-gpu/starter_gpu.py @@ -11,9 +11,7 @@ from ..starter.starter import get_distribution_template as get_starter_distribut def get_distribution_template() -> DistributionTemplate: - template = get_starter_distribution_template() - name = "starter-gpu" - template.name = name + template = get_starter_distribution_template(name="starter-gpu") template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments." template.providers["post_training"] = [ diff --git a/llama_stack/distributions/starter/starter.py b/llama_stack/distributions/starter/starter.py index a4bbc6371..2fca52700 100644 --- a/llama_stack/distributions/starter/starter.py +++ b/llama_stack/distributions/starter/starter.py @@ -99,9 +99,8 @@ def get_remote_inference_providers() -> list[Provider]: return inference_providers -def get_distribution_template() -> DistributionTemplate: +def get_distribution_template(name: str = "starter") -> DistributionTemplate: remote_inference_providers = get_remote_inference_providers() - name = "starter" providers = { "inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers] diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py index 26f0ad15a..e049518a4 100644 --- a/llama_stack/providers/inline/batches/reference/batches.py +++ b/llama_stack/providers/inline/batches/reference/batches.py @@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches): # TODO: set expiration time for garbage collection - if endpoint not in ["/v1/chat/completions"]: + if endpoint not in ["/v1/chat/completions", "/v1/completions"]: raise ValueError( - f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint", + f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint", ) if completion_window != "24h": @@ -424,13 +424,21 @@ class ReferenceBatchesImpl(Batches): ) valid = False - for param, expected_type, type_string in [ - ("model", str, "a string"), - # messages is specific to /v1/chat/completions - # we could skip validating messages here and let inference fail. however, - # that would be a very expensive way to find out messages is wrong. - ("messages", list, "an array"), # TODO: allow messages to be a string? - ]: + if batch.endpoint == "/v1/chat/completions": + required_params = [ + ("model", str, "a string"), + # messages is specific to /v1/chat/completions + # we could skip validating messages here and let inference fail. however, + # that would be a very expensive way to find out messages is wrong. + ("messages", list, "an array"), # TODO: allow messages to be a string? + ] + else: # /v1/completions + required_params = [ + ("model", str, "a string"), + ("prompt", str, "a string"), # TODO: allow prompt to be a list of strings?? + ] + + for param, expected_type, type_string in required_params: if param not in body: errors.append( BatchError( @@ -591,20 +599,37 @@ class ReferenceBatchesImpl(Batches): try: # TODO(SECURITY): review body for security issues - request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]] - chat_response = await self.inference_api.openai_chat_completion(**request.body) + if request.url == "/v1/chat/completions": + request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]] + chat_response = await self.inference_api.openai_chat_completion(**request.body) - # this is for mypy, we don't allow streaming so we'll get the right type - assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method" - return { - "id": request_id, - "custom_id": request.custom_id, - "response": { - "status_code": 200, - "request_id": request_id, # TODO: should this be different? - "body": chat_response.model_dump_json(), - }, - } + # this is for mypy, we don't allow streaming so we'll get the right type + assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method" + return { + "id": request_id, + "custom_id": request.custom_id, + "response": { + "status_code": 200, + "request_id": request_id, # TODO: should this be different? + "body": chat_response.model_dump_json(), + }, + } + else: # /v1/completions + completion_response = await self.inference_api.openai_completion(**request.body) + + # this is for mypy, we don't allow streaming so we'll get the right type + assert hasattr(completion_response, "model_dump_json"), ( + "Completion response must have model_dump_json method" + ) + return { + "id": request_id, + "custom_id": request.custom_id, + "response": { + "status_code": 200, + "request_id": request_id, + "body": completion_response.model_dump_json(), + }, + } except Exception as e: logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}") return { diff --git a/llama_stack/providers/inline/tool_runtime/rag/__init__.py b/llama_stack/providers/inline/tool_runtime/rag/__init__.py index f9a6e5c55..f9a7e7b89 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/__init__.py +++ b/llama_stack/providers/inline/tool_runtime/rag/__init__.py @@ -14,6 +14,6 @@ from .config import RagToolRuntimeConfig async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]): from .memory import MemoryToolRuntimeImpl - impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference]) + impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files]) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index a1543457b..cb526e8ee 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -5,10 +5,15 @@ # the root directory of this source tree. import asyncio +import base64 +import io +import mimetypes import secrets import string from typing import Any +import httpx +from fastapi import UploadFile from pydantic import TypeAdapter from llama_stack.apis.common.content_types import ( @@ -17,6 +22,7 @@ from llama_stack.apis.common.content_types import ( InterleavedContentItem, TextContentItem, ) +from llama_stack.apis.files import Files, OpenAIFilePurpose from llama_stack.apis.inference import Inference from llama_stack.apis.tools import ( ListToolDefsResponse, @@ -30,13 +36,18 @@ from llama_stack.apis.tools import ( ToolParameter, ToolRuntime, ) -from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO +from llama_stack.apis.vector_io import ( + QueryChunksResponse, + VectorIO, + VectorStoreChunkingStrategyStatic, + VectorStoreChunkingStrategyStaticConfig, +) from llama_stack.log import get_logger from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.memory.vector_store import ( content_from_doc, - make_overlapped_chunks, + parse_data_url, ) from .config import RagToolRuntimeConfig @@ -55,10 +66,12 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti config: RagToolRuntimeConfig, vector_io_api: VectorIO, inference_api: Inference, + files_api: Files, ): self.config = config self.vector_io_api = vector_io_api self.inference_api = inference_api + self.files_api = files_api async def initialize(self): pass @@ -78,27 +91,50 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti vector_db_id: str, chunk_size_in_tokens: int = 512, ) -> None: - chunks = [] + if not documents: + return + for doc in documents: - content = await content_from_doc(doc) - # TODO: we should add enrichment here as URLs won't be added to the metadata by default - chunks.extend( - make_overlapped_chunks( - doc.document_id, - content, - chunk_size_in_tokens, - chunk_size_in_tokens // 4, - doc.metadata, + if isinstance(doc.content, URL): + if doc.content.uri.startswith("data:"): + parts = parse_data_url(doc.content.uri) + file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode() + mime_type = parts["mimetype"] + else: + async with httpx.AsyncClient() as client: + response = await client.get(doc.content.uri) + file_data = response.content + mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream") + else: + content_str = await content_from_doc(doc) + file_data = content_str.encode("utf-8") + mime_type = doc.mime_type or "text/plain" + + file_extension = mimetypes.guess_extension(mime_type) or ".txt" + filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}") + + file_obj = io.BytesIO(file_data) + file_obj.name = filename + + upload_file = UploadFile(file=file_obj, filename=filename) + + created_file = await self.files_api.openai_upload_file( + file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS + ) + + chunking_strategy = VectorStoreChunkingStrategyStatic( + static=VectorStoreChunkingStrategyStaticConfig( + max_chunk_size_tokens=chunk_size_in_tokens, + chunk_overlap_tokens=chunk_size_in_tokens // 4, ) ) - if not chunks: - return - - await self.vector_io_api.insert_chunks( - chunks=chunks, - vector_db_id=vector_db_id, - ) + await self.vector_io_api.openai_attach_file_to_vector_store( + vector_store_id=vector_db_id, + file_id=created_file.id, + attributes=doc.metadata, + chunking_strategy=chunking_strategy, + ) async def query( self, diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index fb841afdf..7a95fd089 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -116,7 +116,7 @@ def available_providers() -> list[ProviderSpec]: adapter=AdapterSpec( adapter_type="fireworks", pip_packages=[ - "fireworks-ai<=0.18.0", + "fireworks-ai<=0.17.16", ], module="llama_stack.providers.remote.inference.fireworks", config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig", @@ -207,7 +207,7 @@ def available_providers() -> list[ProviderSpec]: api=Api.inference, adapter=AdapterSpec( adapter_type="gemini", - pip_packages=["litellm"], + pip_packages=["litellm", "openai"], module="llama_stack.providers.remote.inference.gemini", config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig", provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator", @@ -270,7 +270,7 @@ Available Models: api=Api.inference, adapter=AdapterSpec( adapter_type="sambanova", - pip_packages=["litellm"], + pip_packages=["litellm", "openai"], module="llama_stack.providers.remote.inference.sambanova", config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig", provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator", diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 661851443..5a58fa7af 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -32,7 +32,7 @@ def available_providers() -> list[ProviderSpec]: ], module="llama_stack.providers.inline.tool_runtime.rag", config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig", - api_dependencies=[Api.vector_io, Api.inference], + api_dependencies=[Api.vector_io, Api.inference, Api.files], description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.", ), remote_provider_spec( diff --git a/llama_stack/providers/remote/inference/gemini/gemini.py b/llama_stack/providers/remote/inference/gemini/gemini.py index b6048eff7..569227fdd 100644 --- a/llama_stack/providers/remote/inference/gemini/gemini.py +++ b/llama_stack/providers/remote/inference/gemini/gemini.py @@ -5,12 +5,13 @@ # the root directory of this source tree. from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import GeminiConfig from .models import MODEL_ENTRIES -class GeminiInferenceAdapter(LiteLLMOpenAIMixin): +class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): def __init__(self, config: GeminiConfig) -> None: LiteLLMOpenAIMixin.__init__( self, @@ -21,6 +22,11 @@ class GeminiInferenceAdapter(LiteLLMOpenAIMixin): ) self.config = config + get_api_key = LiteLLMOpenAIMixin.get_api_key + + def get_base_url(self): + return "https://generativelanguage.googleapis.com/v1beta/openai/" + async def initialize(self) -> None: await super().initialize() diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 96469acac..ee3b0f648 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -4,13 +4,26 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import SambaNovaImplConfig from .models import MODEL_ENTRIES -class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin): +class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): + """ + SambaNova Inference Adapter for Llama Stack. + + Note: The inheritance order is important here. OpenAIMixin must come before + LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability() + is used instead of LiteLLMOpenAIMixin.check_model_availability(). + + - OpenAIMixin.check_model_availability() queries the /v1/models to check if a model exists + - LiteLLMOpenAIMixin.check_model_availability() checks the static registry within LiteLLM + """ + def __init__(self, config: SambaNovaImplConfig): self.config = config self.environment_available_models = [] @@ -24,3 +37,14 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin): download_images=True, # SambaNova requires base64 image encoding json_schema_strict=False, # SambaNova doesn't support strict=True yet ) + + # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin + get_api_key = LiteLLMOpenAIMixin.get_api_key + + def get_base_url(self) -> str: + """ + Get the base URL for OpenAI mixin. + + :return: The SambaNova base URL + """ + return self.config.url diff --git a/llama_stack/providers/utils/bedrock/config.py b/llama_stack/providers/utils/bedrock/config.py index b25617d76..2745c88cb 100644 --- a/llama_stack/providers/utils/bedrock/config.py +++ b/llama_stack/providers/utils/bedrock/config.py @@ -4,53 +4,55 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import os + from pydantic import BaseModel, Field class BedrockBaseConfig(BaseModel): aws_access_key_id: str | None = Field( - default=None, + default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"), description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID", ) aws_secret_access_key: str | None = Field( - default=None, + default_factory=lambda: os.getenv("AWS_SECRET_ACCESS_KEY"), description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY", ) aws_session_token: str | None = Field( - default=None, + default_factory=lambda: os.getenv("AWS_SESSION_TOKEN"), description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN", ) region_name: str | None = Field( - default=None, + default_factory=lambda: os.getenv("AWS_DEFAULT_REGION"), description="The default AWS Region to use, for example, us-west-1 or us-west-2." "Default use environment variable: AWS_DEFAULT_REGION", ) profile_name: str | None = Field( - default=None, + default_factory=lambda: os.getenv("AWS_PROFILE"), description="The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE", ) total_max_attempts: int | None = Field( - default=None, + default_factory=lambda: int(val) if (val := os.getenv("AWS_MAX_ATTEMPTS")) else None, description="An integer representing the maximum number of attempts that will be made for a single request, " "including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS", ) retry_mode: str | None = Field( - default=None, + default_factory=lambda: os.getenv("AWS_RETRY_MODE"), description="A string representing the type of retries Boto3 will perform." "Default use environment variable: AWS_RETRY_MODE", ) connect_timeout: float | None = Field( - default=60, + default_factory=lambda: float(os.getenv("AWS_CONNECT_TIMEOUT", "60")), description="The time in seconds till a timeout exception is thrown when attempting to make a connection. " "The default is 60 seconds.", ) read_timeout: float | None = Field( - default=60, + default_factory=lambda: float(os.getenv("AWS_READ_TIMEOUT", "60")), description="The time in seconds till a timeout exception is thrown when attempting to read from a connection." "The default is 60 seconds.", ) session_ttl: int | None = Field( - default=3600, + default_factory=lambda: int(os.getenv("AWS_SESSION_TTL", "3600")), description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).", ) diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index 65ba2854b..9bd0aa8ce 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio import base64 import struct from typing import TYPE_CHECKING @@ -43,9 +44,11 @@ class SentenceTransformerEmbeddingMixin: task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) - embedding_model = self._load_sentence_transformer_model(model.provider_resource_id) - embeddings = embedding_model.encode( - [interleaved_content_as_str(content) for content in contents], show_progress_bar=False + embedding_model = await self._load_sentence_transformer_model(model.provider_resource_id) + embeddings = await asyncio.to_thread( + embedding_model.encode, + [interleaved_content_as_str(content) for content in contents], + show_progress_bar=False, ) return EmbeddingsResponse(embeddings=embeddings) @@ -64,8 +67,8 @@ class SentenceTransformerEmbeddingMixin: # Get the model and generate embeddings model_obj = await self.model_store.get_model(model) - embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id) - embeddings = embedding_model.encode(input_list, show_progress_bar=False) + embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id) + embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False) # Convert embeddings to the requested format data = [] @@ -93,7 +96,7 @@ class SentenceTransformerEmbeddingMixin: usage=usage, ) - def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer": + async def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer": global EMBEDDING_MODELS loaded_model = EMBEDDING_MODELS.get(model) @@ -101,8 +104,12 @@ class SentenceTransformerEmbeddingMixin: return loaded_model log.info(f"Loading sentence transformer for {model}...") - from sentence_transformers import SentenceTransformer - loaded_model = SentenceTransformer(model) + def _load_model(): + from sentence_transformers import SentenceTransformer + + return SentenceTransformer(model) + + loaded_model = await asyncio.to_thread(_load_model) EMBEDDING_MODELS[model] = loaded_model return loaded_model diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py index 02f7aaf8a..fc8e2f377 100644 --- a/llama_stack/providers/utils/tools/mcp.py +++ b/llama_stack/providers/utils/tools/mcp.py @@ -67,6 +67,38 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat raise AuthenticationRequiredError(exc) from exc if i == len(connection_strategies) - 1: raise + except* httpx.ConnectError as eg: + # Connection refused, server down, network unreachable + if i == len(connection_strategies) - 1: + error_msg = f"Failed to connect to MCP server at {endpoint}: Connection refused" + logger.error(f"MCP connection error: {error_msg}") + raise ConnectionError(error_msg) from eg + else: + logger.warning( + f"failed to connect to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}" + ) + except* httpx.TimeoutException as eg: + # Request timeout, server too slow + if i == len(connection_strategies) - 1: + error_msg = f"MCP server at {endpoint} timed out" + logger.error(f"MCP timeout error: {error_msg}") + raise TimeoutError(error_msg) from eg + else: + logger.warning( + f"MCP server at {endpoint} timed out via {strategy.name}, falling back to {connection_strategies[i + 1].name}" + ) + except* httpx.RequestError as eg: + # DNS resolution failures, network errors, invalid URLs + if i == len(connection_strategies) - 1: + # Get the first exception's message for the error string + exc_msg = str(eg.exceptions[0]) if eg.exceptions else "Unknown error" + error_msg = f"Network error connecting to MCP server at {endpoint}: {exc_msg}" + logger.error(f"MCP network error: {error_msg}") + raise ConnectionError(error_msg) from eg + else: + logger.warning( + f"network error connecting to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}" + ) except* McpError: if i < len(connection_strategies) - 1: logger.warning( diff --git a/scripts/github/schedule-record-workflow.sh b/scripts/github/schedule-record-workflow.sh index e381b60b6..09e055611 100755 --- a/scripts/github/schedule-record-workflow.sh +++ b/scripts/github/schedule-record-workflow.sh @@ -15,7 +15,7 @@ set -euo pipefail BRANCH="" TEST_SUBDIRS="" TEST_PROVIDER="ollama" -RUN_VISION_TESTS=false +TEST_SUITE="base" TEST_PATTERN="" # Help function @@ -27,9 +27,9 @@ Trigger the integration test recording workflow remotely. This way you do not ne OPTIONS: -b, --branch BRANCH Branch to run the workflow on (defaults to current branch) - -s, --test-subdirs DIRS Comma-separated list of test subdirectories to run (REQUIRED) -p, --test-provider PROVIDER Test provider to use: vllm or ollama (default: ollama) - -v, --run-vision-tests Include vision tests in the recording + -t, --test-suite SUITE Test suite to use: base, responses, vision, etc. (default: base) + -s, --test-subdirs DIRS Comma-separated list of test subdirectories to run (overrides suite) -k, --test-pattern PATTERN Regex pattern to pass to pytest -k -h, --help Show this help message @@ -38,7 +38,7 @@ EXAMPLES: $0 --test-subdirs "agents" # Record tests for specific branch with vision tests - $0 -b my-feature-branch --test-subdirs "inference" --run-vision-tests + $0 -b my-feature-branch --test-suite vision # Record multiple test subdirectories with specific provider $0 --test-subdirs "agents,inference" --test-provider vllm @@ -71,9 +71,9 @@ while [[ $# -gt 0 ]]; do TEST_PROVIDER="$2" shift 2 ;; - -v|--run-vision-tests) - RUN_VISION_TESTS=true - shift + -t|--test-suite) + TEST_SUITE="$2" + shift 2 ;; -k|--test-pattern) TEST_PATTERN="$2" @@ -92,11 +92,11 @@ while [[ $# -gt 0 ]]; do done # Validate required parameters -if [[ -z "$TEST_SUBDIRS" ]]; then - echo "Error: --test-subdirs is required" - echo "Please specify which test subdirectories to run, e.g.:" +if [[ -z "$TEST_SUBDIRS" && -z "$TEST_SUITE" ]]; then + echo "Error: --test-subdirs or --test-suite is required" + echo "Please specify which test subdirectories to run or test suite to use, e.g.:" echo " $0 --test-subdirs \"agents,inference\"" - echo " $0 --test-subdirs \"inference\" --run-vision-tests" + echo " $0 --test-suite vision" echo "" exit 1 fi @@ -239,17 +239,19 @@ echo "Triggering integration test recording workflow..." echo "Branch: $BRANCH" echo "Test provider: $TEST_PROVIDER" echo "Test subdirs: $TEST_SUBDIRS" -echo "Run vision tests: $RUN_VISION_TESTS" +echo "Test suite: $TEST_SUITE" echo "Test pattern: ${TEST_PATTERN:-"(none)"}" echo "" # Prepare inputs for gh workflow run -INPUTS="-f test-subdirs='$TEST_SUBDIRS'" +if [[ -n "$TEST_SUBDIRS" ]]; then + INPUTS="-f test-subdirs='$TEST_SUBDIRS'" +fi if [[ -n "$TEST_PROVIDER" ]]; then INPUTS="$INPUTS -f test-provider='$TEST_PROVIDER'" fi -if [[ "$RUN_VISION_TESTS" == "true" ]]; then - INPUTS="$INPUTS -f run-vision-tests=true" +if [[ -n "$TEST_SUITE" ]]; then + INPUTS="$INPUTS -f test-suite='$TEST_SUITE'" fi if [[ -n "$TEST_PATTERN" ]]; then INPUTS="$INPUTS -f test-pattern='$TEST_PATTERN'" diff --git a/scripts/integration-tests.sh b/scripts/integration-tests.sh index 104ba5cf3..ab7e37579 100755 --- a/scripts/integration-tests.sh +++ b/scripts/integration-tests.sh @@ -16,7 +16,7 @@ STACK_CONFIG="" PROVIDER="" TEST_SUBDIRS="" TEST_PATTERN="" -RUN_VISION_TESTS="false" +TEST_SUITE="base" INFERENCE_MODE="replay" EXTRA_PARAMS="" @@ -28,12 +28,16 @@ Usage: $0 [OPTIONS] Options: --stack-config STRING Stack configuration to use (required) --provider STRING Provider to use (ollama, vllm, etc.) (required) - --test-subdirs STRING Comma-separated list of test subdirectories to run (default: 'inference') - --run-vision-tests Run vision tests instead of regular tests + --test-suite STRING Comma-separated list of test suites to run (default: 'base') --inference-mode STRING Inference mode: record or replay (default: replay) + --test-subdirs STRING Comma-separated list of test subdirectories to run (overrides suite) --test-pattern STRING Regex pattern to pass to pytest -k --help Show this help message +Suites are defined in tests/integration/suites.py. They are used to narrow the collection of tests and provide default model options. + +You can also specify subdirectories (of tests/integration) to select tests from, which will override the suite. + Examples: # Basic inference tests with ollama $0 --stack-config server:ci-tests --provider ollama @@ -42,7 +46,7 @@ Examples: $0 --stack-config server:ci-tests --provider vllm --test-subdirs 'inference,agents' # Vision tests with ollama - $0 --stack-config server:ci-tests --provider ollama --run-vision-tests + $0 --stack-config server:ci-tests --provider ollama --test-suite vision # Record mode for updating test recordings $0 --stack-config server:ci-tests --provider ollama --inference-mode record @@ -64,9 +68,9 @@ while [[ $# -gt 0 ]]; do TEST_SUBDIRS="$2" shift 2 ;; - --run-vision-tests) - RUN_VISION_TESTS="true" - shift + --test-suite) + TEST_SUITE="$2" + shift 2 ;; --inference-mode) INFERENCE_MODE="$2" @@ -92,22 +96,25 @@ done # Validate required parameters if [[ -z "$STACK_CONFIG" ]]; then echo "Error: --stack-config is required" - usage exit 1 fi if [[ -z "$PROVIDER" ]]; then echo "Error: --provider is required" - usage + exit 1 +fi + +if [[ -z "$TEST_SUITE" && -z "$TEST_SUBDIRS" ]]; then + echo "Error: --test-suite or --test-subdirs is required" exit 1 fi echo "=== Llama Stack Integration Test Runner ===" echo "Stack Config: $STACK_CONFIG" echo "Provider: $PROVIDER" -echo "Test Subdirs: $TEST_SUBDIRS" -echo "Vision Tests: $RUN_VISION_TESTS" echo "Inference Mode: $INFERENCE_MODE" +echo "Test Suite: $TEST_SUITE" +echo "Test Subdirs: $TEST_SUBDIRS" echo "Test Pattern: $TEST_PATTERN" echo "" @@ -194,84 +201,46 @@ if [[ -n "$TEST_PATTERN" ]]; then PYTEST_PATTERN="${PYTEST_PATTERN} and $TEST_PATTERN" fi -# Run vision tests if specified -if [[ "$RUN_VISION_TESTS" == "true" ]]; then - echo "Running vision tests..." - set +e - pytest -s -v tests/integration/inference/test_vision_inference.py \ - --stack-config="$STACK_CONFIG" \ - -k "$PYTEST_PATTERN" \ - --vision-model=ollama/llama3.2-vision:11b \ - --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ - --color=yes $EXTRA_PARAMS \ - --capture=tee-sys - exit_code=$? - set -e - - if [ $exit_code -eq 0 ]; then - echo "✅ Vision tests completed successfully" - elif [ $exit_code -eq 5 ]; then - echo "⚠️ No vision tests collected (pattern matched no tests)" - else - echo "❌ Vision tests failed" - exit 1 - fi - exit 0 -fi - -# Run regular tests -if [[ -z "$TEST_SUBDIRS" ]]; then - TEST_SUBDIRS=$(find tests/integration -maxdepth 1 -mindepth 1 -type d | - sed 's|tests/integration/||' | - grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" | - sort) -fi echo "Test subdirs to run: $TEST_SUBDIRS" -# Collect all test files for the specified test types -TEST_FILES="" -for test_subdir in $(echo "$TEST_SUBDIRS" | tr ',' '\n'); do - # Skip certain test types for vllm provider - if [[ "$PROVIDER" == "vllm" ]]; then - if [[ "$test_subdir" == "safety" ]] || [[ "$test_subdir" == "post_training" ]] || [[ "$test_subdir" == "tool_runtime" ]]; then - echo "Skipping $test_subdir for vllm provider" - continue +if [[ -n "$TEST_SUBDIRS" ]]; then + # Collect all test files for the specified test types + TEST_FILES="" + for test_subdir in $(echo "$TEST_SUBDIRS" | tr ',' '\n'); do + if [[ -d "tests/integration/$test_subdir" ]]; then + # Find all Python test files in this directory + test_files=$(find tests/integration/$test_subdir -name "test_*.py" -o -name "*_test.py") + if [[ -n "$test_files" ]]; then + TEST_FILES="$TEST_FILES $test_files" + echo "Added test files from $test_subdir: $(echo $test_files | wc -w) files" + fi + else + echo "Warning: Directory tests/integration/$test_subdir does not exist" fi + done + + if [[ -z "$TEST_FILES" ]]; then + echo "No test files found for the specified test types" + exit 1 fi - if [[ "$STACK_CONFIG" != *"server:"* ]] && [[ "$test_subdir" == "batches" ]]; then - echo "Skipping $test_subdir for library client until types are supported" - continue - fi + echo "" + echo "=== Running all collected tests in a single pytest command ===" + echo "Total test files: $(echo $TEST_FILES | wc -w)" - if [[ -d "tests/integration/$test_subdir" ]]; then - # Find all Python test files in this directory - test_files=$(find tests/integration/$test_subdir -name "test_*.py" -o -name "*_test.py") - if [[ -n "$test_files" ]]; then - TEST_FILES="$TEST_FILES $test_files" - echo "Added test files from $test_subdir: $(echo $test_files | wc -w) files" - fi - else - echo "Warning: Directory tests/integration/$test_subdir does not exist" - fi -done - -if [[ -z "$TEST_FILES" ]]; then - echo "No test files found for the specified test types" - exit 1 + PYTEST_TARGET="$TEST_FILES" + EXTRA_PARAMS="$EXTRA_PARAMS --text-model=$TEXT_MODEL --embedding-model=sentence-transformers/all-MiniLM-L6-v2" +else + PYTEST_TARGET="tests/integration/" + EXTRA_PARAMS="$EXTRA_PARAMS --suite=$TEST_SUITE" fi -echo "" -echo "=== Running all collected tests in a single pytest command ===" -echo "Total test files: $(echo $TEST_FILES | wc -w)" - set +e -pytest -s -v $TEST_FILES \ +pytest -s -v $PYTEST_TARGET \ --stack-config="$STACK_CONFIG" \ -k "$PYTEST_PATTERN" \ - --text-model="$TEXT_MODEL" \ - --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ - --color=yes $EXTRA_PARAMS \ + $EXTRA_PARAMS \ + --color=yes \ --capture=tee-sys exit_code=$? set -e @@ -294,7 +263,13 @@ df -h # stop server if [[ "$STACK_CONFIG" == *"server:"* ]]; then echo "Stopping Llama Stack Server..." - kill $(lsof -i :8321 | awk 'NR>1 {print $2}') + pids=$(lsof -i :8321 | awk 'NR>1 {print $2}') + if [[ -n "$pids" ]]; then + echo "Killing Llama Stack Server processes: $pids" + kill -9 $pids + else + echo "No Llama Stack Server processes found ?!" + fi echo "Llama Stack Server stopped" fi diff --git a/tests/README.md b/tests/README.md index 81f025f86..c00829d3e 100644 --- a/tests/README.md +++ b/tests/README.md @@ -77,7 +77,7 @@ You must be careful when re-recording. CI workflows assume a specific setup for ./scripts/github/schedule-record-workflow.sh --test-subdirs "agents,inference" # Record with vision tests enabled -./scripts/github/schedule-record-workflow.sh --test-subdirs "inference" --run-vision-tests +./scripts/github/schedule-record-workflow.sh --test-suite vision # Record with specific provider ./scripts/github/schedule-record-workflow.sh --test-subdirs "agents" --test-provider vllm diff --git a/tests/integration/README.md b/tests/integration/README.md index d177cbebf..b05beeb98 100644 --- a/tests/integration/README.md +++ b/tests/integration/README.md @@ -42,6 +42,27 @@ Model parameters can be influenced by the following options: Each of these are comma-separated lists and can be used to generate multiple parameter combinations. Note that tests will be skipped if no model is specified. +### Suites (fast selection + sane defaults) + +- `--suite`: comma-separated list of named suites that both narrow which tests are collected and prefill common model options (unless you pass them explicitly). +- Available suites: + - `responses`: collects tests under `tests/integration/responses`; this is a separate suite because it needs a strong tool-calling model. + - `vision`: collects only `tests/integration/inference/test_vision_inference.py`; defaults `--vision-model=ollama/llama3.2-vision:11b`, `--embedding-model=sentence-transformers/all-MiniLM-L6-v2`. +- Explicit flags always win. For example, `--suite=responses --text-model=` overrides the suite’s text model. + +Examples: + +```bash +# Fast responses run with defaults +pytest -s -v tests/integration --stack-config=server:starter --suite=responses + +# Fast single-file vision run with defaults +pytest -s -v tests/integration --stack-config=server:starter --suite=vision + +# Combine suites and override a default +pytest -s -v tests/integration --stack-config=server:starter --suite=responses,vision --embedding-model=text-embedding-3-small +``` + ## Examples ### Testing against a Server diff --git a/tests/integration/batches/test_batches.py b/tests/integration/batches/test_batches.py index 59811b7a4..d55a68bd3 100644 --- a/tests/integration/batches/test_batches.py +++ b/tests/integration/batches/test_batches.py @@ -268,3 +268,58 @@ class TestBatchesIntegration: deleted_error_file = openai_client.files.delete(final_batch.error_file_id) assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully" + + def test_batch_e2e_completions(self, openai_client, batch_helper, text_model_id): + """Run an end-to-end batch with a single successful text completion request.""" + request_body = {"model": text_model_id, "prompt": "Say completions", "max_tokens": 20} + + batch_requests = [ + { + "custom_id": "success-1", + "method": "POST", + "url": "/v1/completions", + "body": request_body, + } + ] + + with batch_helper.create_file(batch_requests) as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/completions", + completion_window="24h", + metadata={"test": "e2e_completions_success"}, + ) + + final_batch = batch_helper.wait_for( + batch.id, + max_wait_time=3 * 60, + expected_statuses={"completed"}, + timeout_action="skip", + ) + + assert final_batch.status == "completed" + assert final_batch.request_counts is not None + assert final_batch.request_counts.total == 1 + assert final_batch.request_counts.completed == 1 + assert final_batch.output_file_id is not None + + output_content = openai_client.files.content(final_batch.output_file_id) + if isinstance(output_content, str): + output_text = output_content + else: + output_text = output_content.content.decode("utf-8") + + output_lines = output_text.strip().split("\n") + assert len(output_lines) == 1 + + result = json.loads(output_lines[0]) + assert result["custom_id"] == "success-1" + assert "response" in result + assert result["response"]["status_code"] == 200 + + deleted_output_file = openai_client.files.delete(final_batch.output_file_id) + assert deleted_output_file.deleted + + if final_batch.error_file_id is not None: + deleted_error_file = openai_client.files.delete(final_batch.error_file_id) + assert deleted_error_file.deleted diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index fd9a54d04..96260fdb7 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -6,15 +6,17 @@ import inspect import itertools import os -import platform import textwrap import time +from pathlib import Path import pytest from dotenv import load_dotenv from llama_stack.log import get_logger +from .suites import SUITE_DEFINITIONS + logger = get_logger(__name__, category="tests") @@ -61,9 +63,22 @@ def pytest_configure(config): key, value = env_var.split("=", 1) os.environ[key] = value - if platform.system() == "Darwin": # Darwin is the system name for macOS - os.environ["DISABLE_CODE_SANDBOX"] = "1" - logger.info("Setting DISABLE_CODE_SANDBOX=1 for macOS") + suites_raw = config.getoption("--suite") + suites: list[str] = [] + if suites_raw: + suites = [p.strip() for p in str(suites_raw).split(",") if p.strip()] + unknown = [p for p in suites if p not in SUITE_DEFINITIONS] + if unknown: + raise pytest.UsageError( + f"Unknown suite(s): {', '.join(unknown)}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}" + ) + for suite in suites: + suite_def = SUITE_DEFINITIONS.get(suite, {}) + defaults: dict = suite_def.get("defaults", {}) + for dest, value in defaults.items(): + current = getattr(config.option, dest, None) + if not current: + setattr(config.option, dest, value) def pytest_addoption(parser): @@ -105,16 +120,21 @@ def pytest_addoption(parser): default=384, help="Output dimensionality of the embedding model to use for testing. Default: 384", ) - parser.addoption( - "--record-responses", - action="store_true", - help="Record new API responses instead of using cached ones.", - ) parser.addoption( "--report", help="Path where the test report should be written, e.g. --report=/path/to/report.md", ) + available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys())) + suite_help = ( + "Comma-separated integration test suites to narrow collection and prefill defaults. " + "Available: " + f"{available_suites}. " + "Explicit CLI flags (e.g., --text-model) override suite defaults. " + "Examples: --suite=responses or --suite=responses,vision." + ) + parser.addoption("--suite", help=suite_help) + MODEL_SHORT_IDS = { "meta-llama/Llama-3.2-3B-Instruct": "3B", @@ -197,3 +217,40 @@ def pytest_generate_tests(metafunc): pytest_plugins = ["tests.integration.fixtures.common"] + + +def pytest_ignore_collect(path: str, config: pytest.Config) -> bool: + """Skip collecting paths outside the selected suite roots for speed.""" + suites_raw = config.getoption("--suite") + if not suites_raw: + return False + + names = [p.strip() for p in str(suites_raw).split(",") if p.strip()] + roots: list[str] = [] + for name in names: + suite_def = SUITE_DEFINITIONS.get(name) + if suite_def: + roots.extend(suite_def.get("roots", [])) + if not roots: + return False + + p = Path(str(path)).resolve() + + # Only constrain within tests/integration to avoid ignoring unrelated tests + integration_root = (Path(str(config.rootpath)) / "tests" / "integration").resolve() + if not p.is_relative_to(integration_root): + return False + + for r in roots: + rp = (Path(str(config.rootpath)) / r).resolve() + if rp.is_file(): + # Allow the exact file and any ancestor directories so pytest can walk into it. + if p == rp: + return False + if p.is_dir() and rp.is_relative_to(p): + return False + else: + # Allow anything inside an allowed directory + if p.is_relative_to(rp): + return False + return True diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 72137662d..099032578 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -5,6 +5,8 @@ # the root directory of this source tree. +import time + import pytest from ..test_cases.test_case import TestCase @@ -35,6 +37,7 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id) "remote::sambanova", "remote::tgi", "remote::vertexai", + "remote::gemini", # https://generativelanguage.googleapis.com/v1beta/openai/completions -> 404 ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.") @@ -56,6 +59,18 @@ def skip_if_model_doesnt_support_suffix(client_with_models, model_id): pytest.skip(f"Provider {provider.provider_type} doesn't support suffix.") +def skip_if_doesnt_support_n(client_with_models, model_id): + provider = provider_from_model(client_with_models, model_id) + if provider.provider_type in ( + "remote::sambanova", + "remote::ollama", + # Error code: 400 - [{'error': {'code': 400, 'message': 'Only one candidate can be specified in the + # current model', 'status': 'INVALID_ARGUMENT'}}] + "remote::gemini", + ): + pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.") + + def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id): provider = provider_from_model(client_with_models, model_id) if provider.provider_type in ( @@ -260,10 +275,7 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex ) def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case): skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) - - provider = provider_from_model(client_with_models, text_model_id) - if provider.provider_type == "remote::ollama": - pytest.skip(f"Model {text_model_id} hosted by {provider.provider_type} doesn't support n > 1.") + skip_if_doesnt_support_n(client_with_models, text_model_id) tc = TestCase(test_case) question = tc["question"] @@ -323,8 +335,15 @@ def test_inference_store(compat_client, client_with_models, text_model_id, strea response_id = response.id content = response.choices[0].message.content - responses = client.chat.completions.list(limit=1000) - assert response_id in [r.id for r in responses.data] + tries = 0 + while tries < 10: + responses = client.chat.completions.list(limit=1000) + if response_id in [r.id for r in responses.data]: + break + else: + tries += 1 + time.sleep(0.1) + assert tries < 10, f"Response {response_id} not found after 1 second" retrieved_response = client.chat.completions.retrieve(response_id) assert retrieved_response.id == response_id @@ -388,6 +407,18 @@ def test_inference_store_tool_calls(compat_client, client_with_models, text_mode response_id = response.id content = response.choices[0].message.content + # wait for the response to be stored + tries = 0 + while tries < 10: + responses = client.chat.completions.list(limit=1000) + if response_id in [r.id for r in responses.data]: + break + else: + tries += 1 + time.sleep(0.1) + + assert tries < 10, f"Response {response_id} not found after 1 second" + responses = client.chat.completions.list(limit=1000) assert response_id in [r.id for r in responses.data] diff --git a/tests/integration/recordings/responses/41e27b9b5d09.json b/tests/integration/recordings/responses/41e27b9b5d09.json new file mode 100644 index 000000000..45d140843 --- /dev/null +++ b/tests/integration/recordings/responses/41e27b9b5d09.json @@ -0,0 +1,42 @@ +{ + "request": { + "method": "POST", + "url": "http://0.0.0.0:11434/v1/v1/completions", + "headers": {}, + "body": { + "model": "llama3.2:3b-instruct-fp16", + "prompt": "Say completions", + "max_tokens": 20 + }, + "endpoint": "/v1/completions", + "model": "llama3.2:3b-instruct-fp16" + }, + "response": { + "body": { + "__type__": "openai.types.completion.Completion", + "__data__": { + "id": "cmpl-271", + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "text": "You want me to respond with a completion, but you didn't specify what I should complete. Could" + } + ], + "created": 1756846620, + "model": "llama3.2:3b-instruct-fp16", + "object": "text_completion", + "system_fingerprint": "fp_ollama", + "usage": { + "completion_tokens": 20, + "prompt_tokens": 28, + "total_tokens": 48, + "completion_tokens_details": null, + "prompt_tokens_details": null + } + } + }, + "is_streaming": false + } +} diff --git a/tests/integration/non_ci/responses/__init__.py b/tests/integration/responses/__init__.py similarity index 100% rename from tests/integration/non_ci/responses/__init__.py rename to tests/integration/responses/__init__.py diff --git a/tests/integration/non_ci/responses/fixtures/__init__.py b/tests/integration/responses/fixtures/__init__.py similarity index 100% rename from tests/integration/non_ci/responses/fixtures/__init__.py rename to tests/integration/responses/fixtures/__init__.py diff --git a/tests/integration/non_ci/responses/fixtures/fixtures.py b/tests/integration/responses/fixtures/fixtures.py similarity index 100% rename from tests/integration/non_ci/responses/fixtures/fixtures.py rename to tests/integration/responses/fixtures/fixtures.py diff --git a/tests/integration/non_ci/responses/fixtures/images/vision_test_1.jpg b/tests/integration/responses/fixtures/images/vision_test_1.jpg similarity index 100% rename from tests/integration/non_ci/responses/fixtures/images/vision_test_1.jpg rename to tests/integration/responses/fixtures/images/vision_test_1.jpg diff --git a/tests/integration/non_ci/responses/fixtures/images/vision_test_2.jpg b/tests/integration/responses/fixtures/images/vision_test_2.jpg similarity index 100% rename from tests/integration/non_ci/responses/fixtures/images/vision_test_2.jpg rename to tests/integration/responses/fixtures/images/vision_test_2.jpg diff --git a/tests/integration/non_ci/responses/fixtures/images/vision_test_3.jpg b/tests/integration/responses/fixtures/images/vision_test_3.jpg similarity index 100% rename from tests/integration/non_ci/responses/fixtures/images/vision_test_3.jpg rename to tests/integration/responses/fixtures/images/vision_test_3.jpg diff --git a/tests/integration/non_ci/responses/fixtures/pdfs/llama_stack_and_models.pdf b/tests/integration/responses/fixtures/pdfs/llama_stack_and_models.pdf similarity index 100% rename from tests/integration/non_ci/responses/fixtures/pdfs/llama_stack_and_models.pdf rename to tests/integration/responses/fixtures/pdfs/llama_stack_and_models.pdf diff --git a/tests/integration/non_ci/responses/fixtures/test_cases.py b/tests/integration/responses/fixtures/test_cases.py similarity index 100% rename from tests/integration/non_ci/responses/fixtures/test_cases.py rename to tests/integration/responses/fixtures/test_cases.py diff --git a/tests/integration/non_ci/responses/helpers.py b/tests/integration/responses/helpers.py similarity index 100% rename from tests/integration/non_ci/responses/helpers.py rename to tests/integration/responses/helpers.py diff --git a/tests/integration/non_ci/responses/streaming_assertions.py b/tests/integration/responses/streaming_assertions.py similarity index 100% rename from tests/integration/non_ci/responses/streaming_assertions.py rename to tests/integration/responses/streaming_assertions.py diff --git a/tests/integration/non_ci/responses/test_basic_responses.py b/tests/integration/responses/test_basic_responses.py similarity index 100% rename from tests/integration/non_ci/responses/test_basic_responses.py rename to tests/integration/responses/test_basic_responses.py diff --git a/tests/integration/non_ci/responses/test_file_search.py b/tests/integration/responses/test_file_search.py similarity index 100% rename from tests/integration/non_ci/responses/test_file_search.py rename to tests/integration/responses/test_file_search.py diff --git a/tests/integration/non_ci/responses/test_tool_responses.py b/tests/integration/responses/test_tool_responses.py similarity index 100% rename from tests/integration/non_ci/responses/test_tool_responses.py rename to tests/integration/responses/test_tool_responses.py diff --git a/tests/integration/suites.py b/tests/integration/suites.py new file mode 100644 index 000000000..602855055 --- /dev/null +++ b/tests/integration/suites.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Central definition of integration test suites. You can use these suites by passing --suite=name to pytest. +# For example: +# +# ```bash +# pytest tests/integration/ --suite=vision +# ``` +# +# Each suite can: +# - restrict collection to specific roots (dirs or files) +# - provide default CLI option values (e.g. text_model, embedding_model, etc.) + +from pathlib import Path + +this_dir = Path(__file__).parent +default_roots = [ + str(p) + for p in this_dir.glob("*") + if p.is_dir() + and p.name not in ("__pycache__", "fixtures", "test_cases", "recordings", "responses", "post_training") +] + +SUITE_DEFINITIONS: dict[str, dict] = { + "base": { + "description": "Base suite that includes most tests but runs them with a text Ollama model", + "roots": default_roots, + "defaults": { + "text_model": "ollama/llama3.2:3b-instruct-fp16", + "embedding_model": "sentence-transformers/all-MiniLM-L6-v2", + }, + }, + "responses": { + "description": "Suite that includes only the OpenAI Responses tests; needs a strong tool-calling model", + "roots": ["tests/integration/responses"], + "defaults": { + "text_model": "openai/gpt-4o", + "embedding_model": "sentence-transformers/all-MiniLM-L6-v2", + }, + }, + "vision": { + "description": "Suite that includes only the vision tests", + "roots": ["tests/integration/inference/test_vision_inference.py"], + "defaults": { + "vision_model": "ollama/llama3.2-vision:11b", + "embedding_model": "sentence-transformers/all-MiniLM-L6-v2", + }, + }, +} diff --git a/tests/integration/tool_runtime/test_rag_tool.py b/tests/integration/tool_runtime/test_rag_tool.py index 2affe2a2d..b208500d8 100644 --- a/tests/integration/tool_runtime/test_rag_tool.py +++ b/tests/integration/tool_runtime/test_rag_tool.py @@ -17,10 +17,14 @@ def client_with_empty_registry(client_with_models): client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id) clear_registry() + + try: + client_with_models.toolgroups.register(toolgroup_id="builtin::rag", provider_id="rag-runtime") + except Exception: + pass + yield client_with_models - # you must clean after the last test if you were running tests against - # a stateful server instance clear_registry() @@ -66,12 +70,13 @@ def assert_valid_text_response(response): def test_vector_db_insert_inline_and_query( client_with_empty_registry, sample_documents, embedding_model_id, embedding_dimension ): - vector_db_id = "test_vector_db" - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, + vector_db_name = "test_vector_db" + vector_db = client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_name, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, ) + vector_db_id = vector_db.identifier client_with_empty_registry.tool_runtime.rag_tool.insert( documents=sample_documents, @@ -134,7 +139,11 @@ def test_vector_db_insert_from_url_and_query( # list to check memory bank is successfully registered available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] - assert vector_db_id in available_vector_dbs + # VectorDB is being migrated to VectorStore, so the ID will be different + # Just check that at least one vector DB was registered + assert len(available_vector_dbs) > 0 + # Use the actual registered vector_db_id for subsequent operations + actual_vector_db_id = available_vector_dbs[0] urls = [ "memory_optimizations.rst", @@ -153,13 +162,13 @@ def test_vector_db_insert_from_url_and_query( client_with_empty_registry.tool_runtime.rag_tool.insert( documents=documents, - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, chunk_size_in_tokens=512, ) # Query for the name of method response1 = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query="What's the name of the fine-tunning method used?", ) assert_valid_chunk_response(response1) @@ -167,7 +176,7 @@ def test_vector_db_insert_from_url_and_query( # Query for the name of model response2 = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query="Which Llama model is mentioned?", ) assert_valid_chunk_response(response2) @@ -187,7 +196,11 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i ) available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] - assert vector_db_id in available_vector_dbs + # VectorDB is being migrated to VectorStore, so the ID will be different + # Just check that at least one vector DB was registered + assert len(available_vector_dbs) > 0 + # Use the actual registered vector_db_id for subsequent operations + actual_vector_db_id = available_vector_dbs[0] urls = [ "memory_optimizations.rst", @@ -206,19 +219,19 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i client_with_empty_registry.tool_runtime.rag_tool.insert( documents=documents, - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, chunk_size_in_tokens=512, ) response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query( - vector_db_ids=[vector_db_id], + vector_db_ids=[actual_vector_db_id], content="What is the name of the method used for fine-tuning?", ) assert_valid_text_response(response_with_metadata) assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content) response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query( - vector_db_ids=[vector_db_id], + vector_db_ids=[actual_vector_db_id], content="What is the name of the method used for fine-tuning?", query_config={ "include_metadata_in_content": True, @@ -230,7 +243,7 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i with pytest.raises((ValueError, BadRequestError)): client_with_empty_registry.tool_runtime.rag_tool.query( - vector_db_ids=[vector_db_id], + vector_db_ids=[actual_vector_db_id], content="What is the name of the method used for fine-tuning?", query_config={ "chunk_template": "This should raise a ValueError because it is missing the proper template variables", diff --git a/tests/integration/vector_io/test_vector_io.py b/tests/integration/vector_io/test_vector_io.py index 07faa0db1..979eff6bb 100644 --- a/tests/integration/vector_io/test_vector_io.py +++ b/tests/integration/vector_io/test_vector_io.py @@ -47,34 +47,45 @@ def client_with_empty_registry(client_with_models): def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension): - # Register a memory bank first - vector_db_id = "test_vector_db" - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, + vector_db_name = "test_vector_db" + register_response = client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_name, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, ) + actual_vector_db_id = register_response.identifier + # Retrieve the memory bank and validate its properties - response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=vector_db_id) + response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=actual_vector_db_id) assert response is not None - assert response.identifier == vector_db_id + assert response.identifier == actual_vector_db_id assert response.embedding_model == embedding_model_id - assert response.provider_resource_id == vector_db_id + assert response.identifier.startswith("vs_") def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension): - vector_db_id = "test_vector_db" - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, + vector_db_name = "test_vector_db" + response = client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_name, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, ) - vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] - assert vector_dbs_after_register == [vector_db_id] + actual_vector_db_id = response.identifier + assert actual_vector_db_id.startswith("vs_") + assert actual_vector_db_id != vector_db_name - client_with_empty_registry.vector_dbs.unregister(vector_db_id=vector_db_id) + vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] + assert vector_dbs_after_register == [actual_vector_db_id] + + vector_stores = client_with_empty_registry.vector_stores.list() + assert len(vector_stores.data) == 1 + vector_store = vector_stores.data[0] + assert vector_store.id == actual_vector_db_id + assert vector_store.name == vector_db_name + + client_with_empty_registry.vector_dbs.unregister(vector_db_id=actual_vector_db_id) vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] assert len(vector_dbs) == 0 @@ -91,20 +102,22 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe ], ) def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case): - vector_db_id = "test_vector_db" - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, + vector_db_name = "test_vector_db" + register_response = client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_name, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, ) + actual_vector_db_id = register_response.identifier + client_with_empty_registry.vector_io.insert( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, chunks=sample_chunks, ) response = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query="What is the capital of France?", ) assert response is not None @@ -113,7 +126,7 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding query, expected_doc_id = test_case response = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query=query, ) assert response is not None @@ -128,13 +141,15 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e "remote::qdrant": {"score_threshold": -1.0}, "inline::qdrant": {"score_threshold": -1.0}, } - vector_db_id = "test_precomputed_embeddings_db" - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, + vector_db_name = "test_precomputed_embeddings_db" + register_response = client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_name, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, ) + actual_vector_db_id = register_response.identifier + chunks_with_embeddings = [ Chunk( content="This is a test chunk with precomputed embedding.", @@ -144,13 +159,13 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e ] client_with_empty_registry.vector_io.insert( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, chunks=chunks_with_embeddings, ) provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0] response = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query="precomputed embedding test", params=vector_io_provider_params_dict.get(provider, None), ) @@ -173,13 +188,15 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb( "remote::qdrant": {"score_threshold": 0.0}, "inline::qdrant": {"score_threshold": 0.0}, } - vector_db_id = "test_precomputed_embeddings_db" - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, + vector_db_name = "test_precomputed_embeddings_db" + register_response = client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_name, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, ) + actual_vector_db_id = register_response.identifier + chunks_with_embeddings = [ Chunk( content="duplicate", @@ -189,13 +206,13 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb( ] client_with_empty_registry.vector_io.insert( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, chunks=chunks_with_embeddings, ) provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0] response = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query="duplicate", params=vector_io_provider_params_dict.get(provider, None), ) diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 2652f5c8d..1ceee81c6 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -146,6 +146,20 @@ class VectorDBImpl(Impl): async def unregister_vector_db(self, vector_db_id: str): return vector_db_id + async def openai_create_vector_store(self, **kwargs): + import time + import uuid + + from llama_stack.apis.vector_io.vector_io import VectorStoreFileCounts, VectorStoreObject + + vector_store_id = kwargs.get("provider_vector_db_id") or f"vs_{uuid.uuid4()}" + return VectorStoreObject( + id=vector_store_id, + name=kwargs.get("name", vector_store_id), + created_at=int(time.time()), + file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0), + ) + async def test_models_routing_table(cached_disk_dist_registry): table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) @@ -247,17 +261,21 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry): ) # Register multiple vector databases and verify listing - await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model") - await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model") + vdb1 = await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model") + vdb2 = await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model") vector_dbs = await table.list_vector_dbs() assert len(vector_dbs.data) == 2 vector_db_ids = {v.identifier for v in vector_dbs.data} - assert "test-vectordb" in vector_db_ids - assert "test-vectordb-2" in vector_db_ids + assert vdb1.identifier in vector_db_ids + assert vdb2.identifier in vector_db_ids - await table.unregister_vector_db(vector_db_id="test-vectordb") - await table.unregister_vector_db(vector_db_id="test-vectordb-2") + # Verify they have UUID-based identifiers + assert vdb1.identifier.startswith("vs_") + assert vdb2.identifier.startswith("vs_") + + await table.unregister_vector_db(vector_db_id=vdb1.identifier) + await table.unregister_vector_db(vector_db_id=vdb2.identifier) vector_dbs = await table.list_vector_dbs() assert len(vector_dbs.data) == 0 diff --git a/tests/unit/distribution/routing_tables/test_vector_dbs.py b/tests/unit/distribution/routing_tables/test_vector_dbs.py index 789eda433..3444f64c2 100644 --- a/tests/unit/distribution/routing_tables/test_vector_dbs.py +++ b/tests/unit/distribution/routing_tables/test_vector_dbs.py @@ -7,6 +7,7 @@ # Unit tests for the routing tables vector_dbs import time +import uuid from unittest.mock import AsyncMock import pytest @@ -34,6 +35,7 @@ from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceI class VectorDBImpl(Impl): def __init__(self): super().__init__(Api.vector_io) + self.vector_stores = {} async def register_vector_db(self, vector_db: VectorDB): return vector_db @@ -114,8 +116,35 @@ class VectorDBImpl(Impl): async def openai_delete_vector_store_file(self, vector_store_id, file_id): return VectorStoreFileDeleteResponse(id=file_id, deleted=True) + async def openai_create_vector_store( + self, + name=None, + embedding_model=None, + embedding_dimension=None, + provider_id=None, + provider_vector_db_id=None, + **kwargs, + ): + vector_store_id = provider_vector_db_id or f"vs_{uuid.uuid4()}" + vector_store = VectorStoreObject( + id=vector_store_id, + name=name or vector_store_id, + created_at=int(time.time()), + file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0), + ) + self.vector_stores[vector_store_id] = vector_store + return vector_store + + async def openai_list_vector_stores(self, **kwargs): + from llama_stack.apis.vector_io.vector_io import VectorStoreListResponse + + return VectorStoreListResponse( + data=list(self.vector_stores.values()), has_more=False, first_id=None, last_id=None + ) + async def test_vectordbs_routing_table(cached_disk_dist_registry): + n = 10 table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) await table.initialize() @@ -129,22 +158,98 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry): ) # Register multiple vector databases and verify listing - await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model") - await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model") + vdb_dict = {} + for i in range(n): + vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model") + vector_dbs = await table.list_vector_dbs() - assert len(vector_dbs.data) == 2 + assert len(vector_dbs.data) == len(vdb_dict) vector_db_ids = {v.identifier for v in vector_dbs.data} - assert "test-vectordb" in vector_db_ids - assert "test-vectordb-2" in vector_db_ids - - await table.unregister_vector_db(vector_db_id="test-vectordb") - await table.unregister_vector_db(vector_db_id="test-vectordb-2") + for k in vdb_dict: + assert vdb_dict[k].identifier in vector_db_ids + for k in vdb_dict: + await table.unregister_vector_db(vector_db_id=vdb_dict[k].identifier) vector_dbs = await table.list_vector_dbs() assert len(vector_dbs.data) == 0 +async def test_vector_db_and_vector_store_id_mapping(cached_disk_dist_registry): + n = 10 + impl = VectorDBImpl() + table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {}) + await table.initialize() + + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + vdb_dict = {} + for i in range(n): + vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model") + + vector_dbs = await table.list_vector_dbs() + vector_db_ids = {v.identifier for v in vector_dbs.data} + + vector_stores = await impl.openai_list_vector_stores() + vector_store_ids = {v.id for v in vector_stores.data} + + assert vector_db_ids == vector_store_ids, ( + f"Vector DB IDs {vector_db_ids} don't match vector store IDs {vector_store_ids}" + ) + + for vector_store in vector_stores.data: + vector_db = await table.get_vector_db(vector_store.id) + assert vector_store.name == vector_db.vector_db_name, ( + f"Vector store name {vector_store.name} doesn't match vector store ID {vector_store.id}" + ) + + for vector_db_id in vector_db_ids: + await table.unregister_vector_db(vector_db_id) + + assert len((await table.list_vector_dbs()).data) == 0 + + +async def test_vector_db_id_becomes_vector_store_name(cached_disk_dist_registry): + impl = VectorDBImpl() + table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {}) + await table.initialize() + + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + user_provided_id = "my-custom-vector-db" + await table.register_vector_db(vector_db_id=user_provided_id, embedding_model="test-model") + + vector_stores = await impl.openai_list_vector_stores() + assert len(vector_stores.data) == 1 + + vector_store = vector_stores.data[0] + + assert vector_store.name == user_provided_id + + assert vector_store.id.startswith("vs_") + assert vector_store.id != user_provided_id + + vector_dbs = await table.list_vector_dbs() + assert len(vector_dbs.data) == 1 + assert vector_dbs.data[0].identifier == vector_store.id + + await table.unregister_vector_db(vector_store.id) + + async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry): impl = VectorDBImpl() impl.openai_retrieve_vector_store = AsyncMock(return_value="OK") @@ -164,7 +269,8 @@ async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registr authorized_user = User(principal="alice", attributes={"roles": [authorized_team]}) with request_provider_data_context({}, authorized_user): - _ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model") + registered_vdb = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model") + authorized_table = registered_vdb.identifier # Use the actual generated ID # Authorized reader with request_provider_data_context({}, authorized_user): @@ -227,7 +333,8 @@ async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_regis ) with request_provider_data_context({}, admin_user): - await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model") + registered_vdb = await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model") + vector_db_id = registered_vdb.identifier # Use the actual generated ID read_methods = [ (table.openai_retrieve_vector_store, (vector_db_id,), {}), diff --git a/tests/unit/providers/batches/test_reference.py b/tests/unit/providers/batches/test_reference.py index 0ca866f7b..dfef5e040 100644 --- a/tests/unit/providers/batches/test_reference.py +++ b/tests/unit/providers/batches/test_reference.py @@ -46,7 +46,8 @@ The tests are categorized and outlined below, keep this updated: * test_validate_input_url_mismatch (negative) * test_validate_input_multiple_errors_per_request (negative) * test_validate_input_invalid_request_format (negative) - * test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation) + * test_validate_input_missing_parameters_chat_completions (parametrized negative - custom_id, method, url, body, model, messages missing validation for chat/completions) + * test_validate_input_missing_parameters_completions (parametrized negative - custom_id, method, url, body, model, prompt missing validation for completions) * test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation) The tests use temporary SQLite databases for isolation and mock external @@ -213,7 +214,6 @@ class TestReferenceBatchesImpl: "endpoint", [ "/v1/embeddings", - "/v1/completions", "/v1/invalid/endpoint", "", ], @@ -499,8 +499,10 @@ class TestReferenceBatchesImpl: ("messages", "body.messages", "invalid_request", "Messages parameter is required"), ], ) - async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message): - """Test _validate_input when file contains request with missing required parameters.""" + async def test_validate_input_missing_parameters_chat_completions( + self, provider, param_name, param_path, error_code, error_message + ): + """Test _validate_input when file contains request with missing required parameters for chat completions.""" provider.files_api.openai_retrieve_file = AsyncMock() mock_response = MagicMock() @@ -541,6 +543,61 @@ class TestReferenceBatchesImpl: assert errors[0].message == error_message assert errors[0].param == param_path + @pytest.mark.parametrize( + "param_name,param_path,error_code,error_message", + [ + ("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"), + ("method", "method", "missing_required_parameter", "Missing required parameter: method"), + ("url", "url", "missing_required_parameter", "Missing required parameter: url"), + ("body", "body", "missing_required_parameter", "Missing required parameter: body"), + ("model", "body.model", "invalid_request", "Model parameter is required"), + ("prompt", "body.prompt", "invalid_request", "Prompt parameter is required"), + ], + ) + async def test_validate_input_missing_parameters_completions( + self, provider, param_name, param_path, error_code, error_message + ): + """Test _validate_input when file contains request with missing required parameters for text completions.""" + provider.files_api.openai_retrieve_file = AsyncMock() + mock_response = MagicMock() + + base_request = { + "custom_id": "req-1", + "method": "POST", + "url": "/v1/completions", + "body": {"model": "test-model", "prompt": "Hello"}, + } + + # Remove the specific parameter being tested + if "." in param_path: + top_level, nested_param = param_path.split(".", 1) + del base_request[top_level][nested_param] + else: + del base_request[param_name] + + mock_response.body = json.dumps(base_request).encode() + provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) + + batch = BatchObject( + id="batch_test", + object="batch", + endpoint="/v1/completions", + input_file_id=f"missing_{param_name}_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + assert len(errors) == 1 + assert len(requests) == 0 + + assert errors[0].code == error_code + assert errors[0].line == 1 + assert errors[0].message == error_message + assert errors[0].param == param_path + async def test_validate_input_url_mismatch(self, provider): """Test _validate_input when file contains request with URL that doesn't match batch endpoint.""" provider.files_api.openai_retrieve_file = AsyncMock() diff --git a/tests/unit/providers/inference/bedrock/test_config.py b/tests/unit/providers/inference/bedrock/test_config.py new file mode 100644 index 000000000..1b8639f2e --- /dev/null +++ b/tests/unit/providers/inference/bedrock/test_config.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +from unittest.mock import patch + +from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig + + +class TestBedrockBaseConfig: + def test_defaults_work_without_env_vars(self): + with patch.dict(os.environ, {}, clear=True): + config = BedrockBaseConfig() + + # Basic creds should be None + assert config.aws_access_key_id is None + assert config.aws_secret_access_key is None + assert config.region_name is None + + # Timeouts get defaults + assert config.connect_timeout == 60.0 + assert config.read_timeout == 60.0 + assert config.session_ttl == 3600 + + def test_env_vars_get_picked_up(self): + env_vars = { + "AWS_ACCESS_KEY_ID": "AKIATEST123", + "AWS_SECRET_ACCESS_KEY": "secret123", + "AWS_DEFAULT_REGION": "us-west-2", + "AWS_MAX_ATTEMPTS": "5", + "AWS_RETRY_MODE": "adaptive", + "AWS_CONNECT_TIMEOUT": "30", + } + + with patch.dict(os.environ, env_vars, clear=True): + config = BedrockBaseConfig() + + assert config.aws_access_key_id == "AKIATEST123" + assert config.aws_secret_access_key == "secret123" + assert config.region_name == "us-west-2" + assert config.total_max_attempts == 5 + assert config.retry_mode == "adaptive" + assert config.connect_timeout == 30.0 + + def test_partial_env_setup(self): + # Just setting one timeout var + with patch.dict(os.environ, {"AWS_CONNECT_TIMEOUT": "120"}, clear=True): + config = BedrockBaseConfig() + + assert config.connect_timeout == 120.0 + assert config.read_timeout == 60.0 # still default + assert config.aws_access_key_id is None + + def test_bad_max_attempts_breaks(self): + with patch.dict(os.environ, {"AWS_MAX_ATTEMPTS": "not_a_number"}, clear=True): + try: + BedrockBaseConfig() + raise AssertionError("Should have failed on bad int conversion") + except ValueError: + pass # expected diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py index 05ccecb99..d18d90716 100644 --- a/tests/unit/rag/test_rag_query.py +++ b/tests/unit/rag/test_rag_query.py @@ -19,12 +19,16 @@ from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRunti class TestRagQuery: async def test_query_raises_on_empty_vector_db_ids(self): - rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock()) + rag_tool = MemoryToolRuntimeImpl( + config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock() + ) with pytest.raises(ValueError): await rag_tool.query(content=MagicMock(), vector_db_ids=[]) async def test_query_chunk_metadata_handling(self): - rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock()) + rag_tool = MemoryToolRuntimeImpl( + config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock() + ) content = "test query content" vector_db_ids = ["db1"] diff --git a/tests/unit/server/test_server.py b/tests/unit/server/test_server.py index 803111fc7..f21bbdd67 100644 --- a/tests/unit/server/test_server.py +++ b/tests/unit/server/test_server.py @@ -113,6 +113,15 @@ class TestTranslateException: assert result.status_code == 504 assert result.detail == "Operation timed out: " + def test_translate_connection_error(self): + """Test that ConnectionError is translated to 502 HTTP status.""" + exc = ConnectionError("Failed to connect to MCP server at http://localhost:9999/sse: Connection refused") + result = translate_exception(exc) + + assert isinstance(result, HTTPException) + assert result.status_code == 502 + assert result.detail == "Failed to connect to MCP server at http://localhost:9999/sse: Connection refused" + def test_translate_not_implemented_error(self): """Test that NotImplementedError is translated to 501 HTTP status.""" exc = NotImplementedError("Not implemented")