From 64d2306dd5481be3fdfaa11c415efe3f2788358e Mon Sep 17 00:00:00 2001 From: Derek Higgins Date: Thu, 4 Sep 2025 16:56:32 +0100 Subject: [PATCH 01/17] fix: distro-codegen pre-commit hook file pattern (#3337) Update the file pattern from 'llama_stack/templates' to 'llama_stack/distributions' to properly trigger the Distribution Template Codegen hook when distribution files change. Signed-off-by: Derek Higgins --- .pre-commit-config.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 514fe6d2e..b7880a9fc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -86,7 +86,7 @@ repos: language: python pass_filenames: false require_serial: true - files: ^llama_stack/templates/.*$|^llama_stack/providers/.*/inference/.*/models\.py$ + files: ^llama_stack/distributions/.*$|^llama_stack/providers/.*/inference/.*/models\.py$ - id: provider-codegen name: Provider Codegen additional_dependencies: From 85f33762d7cb7495b9aa5946284e4b6c85f5ad96 Mon Sep 17 00:00:00 2001 From: IAN MILLER <75687988+r3v5@users.noreply.github.com> Date: Thu, 4 Sep 2025 17:15:13 +0100 Subject: [PATCH 02/17] refactor(server): remove hardcoded 409 and 404 status codes in server.py using httpx constants (#3333) # What does this PR do? This PR is eliminating hardcoded status codes: `409` CONFLICT and `404` NOT_FOUND in `server.py` using `httpx` built-in constants. This implementation will follow the existing structure to improve readability, extensibility and developer experience. This is already was implemented in #3131 ## Test Plan `./scripts/unit-tests.sh` --- llama_stack/core/server/server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index d6dfc3435..b247a610d 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -132,9 +132,9 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro }, ) elif isinstance(exc, ConflictError): - return HTTPException(status_code=409, detail=str(exc)) + return HTTPException(status_code=httpx.codes.CONFLICT, detail=str(exc)) elif isinstance(exc, ResourceNotFoundError): - return HTTPException(status_code=404, detail=str(exc)) + return HTTPException(status_code=httpx.codes.NOT_FOUND, detail=str(exc)) elif isinstance(exc, ValueError): return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=f"Invalid value: {str(exc)}") elif isinstance(exc, BadRequestError): From 5bbca56cfc24cd1a1d4d5aff6a9c1c4ad12a741a Mon Sep 17 00:00:00 2001 From: Derek Higgins Date: Thu, 4 Sep 2025 18:58:41 +0100 Subject: [PATCH 03/17] fix: Make SentenceTransformer embedding operations non-blocking (#3335) - Wrap model loading with asyncio.to_thread() to prevent blocking during model download/initialization - Wrap encoding operations with asyncio.to_thread() to run in background thread - Convert _load_sentence_transformer_model() to async method This ensures the async event loop remains responsive during embedding operations. Closes: #3332 Signed-off-by: Derek Higgins Co-authored-by: Francisco Arceo --- .../utils/inference/embedding_mixin.py | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index 65ba2854b..9bd0aa8ce 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio import base64 import struct from typing import TYPE_CHECKING @@ -43,9 +44,11 @@ class SentenceTransformerEmbeddingMixin: task_type: EmbeddingTaskType | None = None, ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) - embedding_model = self._load_sentence_transformer_model(model.provider_resource_id) - embeddings = embedding_model.encode( - [interleaved_content_as_str(content) for content in contents], show_progress_bar=False + embedding_model = await self._load_sentence_transformer_model(model.provider_resource_id) + embeddings = await asyncio.to_thread( + embedding_model.encode, + [interleaved_content_as_str(content) for content in contents], + show_progress_bar=False, ) return EmbeddingsResponse(embeddings=embeddings) @@ -64,8 +67,8 @@ class SentenceTransformerEmbeddingMixin: # Get the model and generate embeddings model_obj = await self.model_store.get_model(model) - embedding_model = self._load_sentence_transformer_model(model_obj.provider_resource_id) - embeddings = embedding_model.encode(input_list, show_progress_bar=False) + embedding_model = await self._load_sentence_transformer_model(model_obj.provider_resource_id) + embeddings = await asyncio.to_thread(embedding_model.encode, input_list, show_progress_bar=False) # Convert embeddings to the requested format data = [] @@ -93,7 +96,7 @@ class SentenceTransformerEmbeddingMixin: usage=usage, ) - def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer": + async def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer": global EMBEDDING_MODELS loaded_model = EMBEDDING_MODELS.get(model) @@ -101,8 +104,12 @@ class SentenceTransformerEmbeddingMixin: return loaded_model log.info(f"Loading sentence transformer for {model}...") - from sentence_transformers import SentenceTransformer - loaded_model = SentenceTransformer(model) + def _load_model(): + from sentence_transformers import SentenceTransformer + + return SentenceTransformer(model) + + loaded_model = await asyncio.to_thread(_load_model) EMBEDDING_MODELS[model] = loaded_model return loaded_model From bcc7f2c7d0a24a4671e98134e6c28f74e6acff26 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Thu, 4 Sep 2025 11:37:46 -0700 Subject: [PATCH 04/17] chore: async inference store write (#3318) # What does this PR do? ## Test Plan ``` cd /docs/source/distributions/k8s-benchmark # start mock server python openai-mock-server.py --port 8000 # start stack server uv run --with llama-stack python -m llama_stack.core.server.server docs/source/distributions/k8s-benchmark/stack_run_config.yaml # run benchmark script uv run python3 benchmark.py --duration 30 --concurrent 50 --base-url=http://localhost:8321/v1/openai/v1 --model=vllm-inference/meta-llama/Llama-3.2-3B-Instruct ``` Before: ============================================================ BENCHMARK RESULTS ============================================================ Total time: 30.00s Concurrent users: 50 Total requests: 1267 Successful requests: 1267 Failed requests: 0 Success rate: 100.0% Requests per second: 42.23 After: ============================================================ BENCHMARK RESULTS ============================================================ Total time: 30.00s Concurrent users: 50 Total requests: 1449 Successful requests: 1449 Failed requests: 0 Success rate: 100.0% Requests per second: 48.30 --- .../distributions/k8s-benchmark/stack_run_config.yaml | 8 ++++++++ llama_stack/core/routers/inference.py | 4 ++-- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/source/distributions/k8s-benchmark/stack_run_config.yaml b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml index ceb1ba2d9..5a810639e 100644 --- a/docs/source/distributions/k8s-benchmark/stack_run_config.yaml +++ b/docs/source/distributions/k8s-benchmark/stack_run_config.yaml @@ -3,6 +3,7 @@ image_name: kubernetes-benchmark-demo apis: - agents - inference +- safety - telemetry - tool_runtime - vector_io @@ -30,6 +31,11 @@ providers: db: ${env.POSTGRES_DB:=llamastack} user: ${env.POSTGRES_USER:=llamastack} password: ${env.POSTGRES_PASSWORD:=llamastack} + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] agents: - provider_id: meta-reference provider_type: inline::meta-reference @@ -95,6 +101,8 @@ models: - model_id: ${env.INFERENCE_MODEL} provider_id: vllm-inference model_type: llm +shields: +- shield_id: ${env.SAFETY_MODEL:=meta-llama/Llama-Guard-3-1B} vector_dbs: [] datasets: [] scoring_fns: [] diff --git a/llama_stack/core/routers/inference.py b/llama_stack/core/routers/inference.py index 8dcad85e3..045093fe0 100644 --- a/llama_stack/core/routers/inference.py +++ b/llama_stack/core/routers/inference.py @@ -527,7 +527,7 @@ class InferenceRouter(Inference): # Store the response with the ID that will be returned to the client if self.store: - await self.store.store_chat_completion(response, messages) + asyncio.create_task(self.store.store_chat_completion(response, messages)) if self.telemetry: metrics = self._construct_metrics( @@ -855,4 +855,4 @@ class InferenceRouter(Inference): object="chat.completion", ) logger.debug(f"InferenceRouter.completion_response: {final_response}") - await self.store.store_chat_completion(final_response, messages) + asyncio.create_task(self.store.store_chat_completion(final_response, messages)) From 561d2fc6b8226f167bb9782e6619d86a092af8e8 Mon Sep 17 00:00:00 2001 From: slekkala1 Date: Thu, 4 Sep 2025 11:47:46 -0700 Subject: [PATCH 05/17] fix: Move to older version for docker container failure [fireworks-ai] (#3338) # What does this PR do? Noticed the test https://github.com/llamastack/llama-stack-ops/actions/workflows/test-maybe-cut.yaml are still failing randomly. Earlier fixed this with 0.18.0 of fireworks here https://github.com/llamastack/llama-stack/pull/3267, the local testing may have inadvertently picked a lower version with `<=` which I assumed picks latest version. Now tested with `==` to find the version where it broke and pinning to version(`<=`) where it was passing. ## Test Plan Tested locally with the following commands to start a container Build container `llama stack build --distro starter --image-type container` start container `docker run -d -p 8321:8321 --name llama-stack-test distribution-starter:0.2.20` check health `http://localhost:8321/v1/health` Above steps fails without the fix Tested with `==` to ensure the same version is picked in local testing instead of anything lower. Following here for the fix from `fireworks-ai` https://discord.com/channels/1137072072808472616/1410674695597981778/1410674695597981778 https://github.com/llamastack/llama-stack/issues/3273 --- llama_stack/providers/registry/inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index fb841afdf..50956f58c 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -116,7 +116,7 @@ def available_providers() -> list[ProviderSpec]: adapter=AdapterSpec( adapter_type="fireworks", pip_packages=[ - "fireworks-ai<=0.18.0", + "fireworks-ai<=0.17.16", ], module="llama_stack.providers.remote.inference.fireworks", config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig", From 55a8c5f439b710cfbb1256e3d6c41a6d82601970 Mon Sep 17 00:00:00 2001 From: Sumanth Kamenani Date: Thu, 4 Sep 2025 16:25:02 -0400 Subject: [PATCH 06/17] fix: show descriptive MCP server connection errors instead of generic 500s (#3256) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit What does this PR do? Fixes error handling when MCP server connections fail. Instead of returning generic 500 errors, now provides descriptive error messages with proper HTTP status codes. Closes #3107 Test Plan Before fix: curl -X GET "http://localhost:8321/v1/tool-runtime/list-tools?tool_group_id=bad-mcp-server" Returns: {"detail": "Internal server error: An unexpected error occurred."} (500) After fix: curl -X GET "http://localhost:8321/v1/tool-runtime/list-tools?tool_group_id=bad-mcp-server" Returns: {"error": {"detail": "Failed to connect to MCP server at http://localhost:9999/sse: Connection refused"}} (502) Tests: - Added unit test for ConnectionError → 502 translation - Manually tested with unreachable MCP servers (connection refused) --- llama_stack/core/server/server.py | 2 ++ llama_stack/providers/utils/tools/mcp.py | 32 ++++++++++++++++++++++++ tests/unit/server/test_server.py | 9 +++++++ 3 files changed, 43 insertions(+) diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index b247a610d..288bf46e1 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -141,6 +141,8 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro return HTTPException(status_code=httpx.codes.BAD_REQUEST, detail=str(exc)) elif isinstance(exc, PermissionError | AccessDeniedError): return HTTPException(status_code=httpx.codes.FORBIDDEN, detail=f"Permission denied: {str(exc)}") + elif isinstance(exc, ConnectionError | httpx.ConnectError): + return HTTPException(status_code=httpx.codes.BAD_GATEWAY, detail=str(exc)) elif isinstance(exc, asyncio.TimeoutError | TimeoutError): return HTTPException(status_code=httpx.codes.GATEWAY_TIMEOUT, detail=f"Operation timed out: {str(exc)}") elif isinstance(exc, NotImplementedError): diff --git a/llama_stack/providers/utils/tools/mcp.py b/llama_stack/providers/utils/tools/mcp.py index 02f7aaf8a..fc8e2f377 100644 --- a/llama_stack/providers/utils/tools/mcp.py +++ b/llama_stack/providers/utils/tools/mcp.py @@ -67,6 +67,38 @@ async def client_wrapper(endpoint: str, headers: dict[str, str]) -> AsyncGenerat raise AuthenticationRequiredError(exc) from exc if i == len(connection_strategies) - 1: raise + except* httpx.ConnectError as eg: + # Connection refused, server down, network unreachable + if i == len(connection_strategies) - 1: + error_msg = f"Failed to connect to MCP server at {endpoint}: Connection refused" + logger.error(f"MCP connection error: {error_msg}") + raise ConnectionError(error_msg) from eg + else: + logger.warning( + f"failed to connect to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}" + ) + except* httpx.TimeoutException as eg: + # Request timeout, server too slow + if i == len(connection_strategies) - 1: + error_msg = f"MCP server at {endpoint} timed out" + logger.error(f"MCP timeout error: {error_msg}") + raise TimeoutError(error_msg) from eg + else: + logger.warning( + f"MCP server at {endpoint} timed out via {strategy.name}, falling back to {connection_strategies[i + 1].name}" + ) + except* httpx.RequestError as eg: + # DNS resolution failures, network errors, invalid URLs + if i == len(connection_strategies) - 1: + # Get the first exception's message for the error string + exc_msg = str(eg.exceptions[0]) if eg.exceptions else "Unknown error" + error_msg = f"Network error connecting to MCP server at {endpoint}: {exc_msg}" + logger.error(f"MCP network error: {error_msg}") + raise ConnectionError(error_msg) from eg + else: + logger.warning( + f"network error connecting to MCP server at {endpoint} via {strategy.name}, falling back to {connection_strategies[i + 1].name}" + ) except* McpError: if i < len(connection_strategies) - 1: logger.warning( diff --git a/tests/unit/server/test_server.py b/tests/unit/server/test_server.py index 803111fc7..f21bbdd67 100644 --- a/tests/unit/server/test_server.py +++ b/tests/unit/server/test_server.py @@ -113,6 +113,15 @@ class TestTranslateException: assert result.status_code == 504 assert result.detail == "Operation timed out: " + def test_translate_connection_error(self): + """Test that ConnectionError is translated to 502 HTTP status.""" + exc = ConnectionError("Failed to connect to MCP server at http://localhost:9999/sse: Connection refused") + result = translate_exception(exc) + + assert isinstance(result, HTTPException) + assert result.status_code == 502 + assert result.detail == "Failed to connect to MCP server at http://localhost:9999/sse: Connection refused" + def test_translate_not_implemented_error(self): """Test that NotImplementedError is translated to 501 HTTP status.""" exc = NotImplementedError("Not implemented") From 3a7ac4227d4a2c0a74e3a7e5eda459ae19761c03 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Thu, 4 Sep 2025 15:13:31 -0700 Subject: [PATCH 07/17] chore: unbreak inference store test (#3340) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? The inference store writes were moved to asyncio.create_task and not await anymore ## Test Plan ❯ OLLAMA_URL=http://localhost:11434 LLAMA_STACK_CONFIG=server:starter uv run --with pytest-repeat pytest tests/integration/inference --text-model="ollama/llama3.2:3b-instruct-fp16" -vvs -k "test_inference_store_tool_calls and 3b-instruct-fp16-True" --count=10 Uninstalled 2 packages in 102ms Installed 2 packages in 138ms INFO 2025-09-04 14:10:17,775 tests.integration.conftest:66 tests: Setting DISABLE_CODE_SANDBOX=1 for macOS ========================================================================================================== test session starts =========================================================================================================== platform darwin -- Python 3.12.3, pytest-8.4.1, pluggy-1.6.0 -- /Users/erichuang/.cache/uv/builds-v0/.tmpSGMlgt/bin/python cachedir: .pytest_cache metadata: {'Python': '3.12.3', 'Platform': 'macOS-15.6.1-arm64-arm-64bit', 'Packages': {'pytest': '8.4.1', 'pluggy': '1.6.0'}, 'Plugins': {'repeat': '0.9.4', 'anyio': '4.9.0', 'html': '4.1.1', 'socket': '0.7.0', 'asyncio': '1.1.0', 'json-report': '1.5.0', 'timeout': '2.4.0', 'metadata': '3.1.1', 'cov': '6.2.1', 'nbval': '0.11.0'}} rootdir: /Users/erichuang/projects/llama-stack-git configfile: pyproject.toml plugins: repeat-0.9.4, anyio-4.9.0, html-4.1.1, socket-0.7.0, asyncio-1.1.0, json-report-1.5.0, timeout-2.4.0, metadata-3.1.1, cov-6.2.1, nbval-0.11.0 asyncio: mode=Mode.AUTO, asyncio_default_fixture_loop_scope=None, asyncio_default_test_loop_scope=function collected 970 items / 950 deselected / 20 selected tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=ollama/llama3.2:3b-instruct-fp16-True-1-10] instantiating llama_stack_client Starting llama stack server with config 'starter' on port 8321... Waiting for server at http://localhost:8321... (0.0s elapsed) Waiting for server at http://localhost:8321... (0.5s elapsed) Waiting for server at http://localhost:8321... (5.1s elapsed) Waiting for server at http://localhost:8321... (5.6s elapsed) Waiting for server at http://localhost:8321... (10.1s elapsed) Waiting for server at http://localhost:8321... (10.6s elapsed) Waiting for server at http://localhost:8321... (15.2s elapsed) Waiting for server at http://localhost:8321... (15.7s elapsed) Server is ready at http://localhost:8321 llama_stack_client instantiated in 20.583s PASSED tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=ollama/llama3.2:3b-instruct-fp16-True-2-10] PASSED tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=ollama/llama3.2:3b-instruct-fp16-True-3-10] PASSED tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=ollama/llama3.2:3b-instruct-fp16-True-4-10] PASSED tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=ollama/llama3.2:3b-instruct-fp16-True-5-10] PASSED tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=ollama/llama3.2:3b-instruct-fp16-True-6-10] PASSED tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=ollama/llama3.2:3b-instruct-fp16-True-7-10] PASSED tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=ollama/llama3.2:3b-instruct-fp16-True-8-10] PASSED tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=ollama/llama3.2:3b-instruct-fp16-True-9-10] PASSED tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[openai_client-txt=ollama/llama3.2:3b-instruct-fp16-True-10-10] PASSED tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=ollama/llama3.2:3b-instruct-fp16-True-1-10] PASSED tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=ollama/llama3.2:3b-instruct-fp16-True-2-10] PASSED tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=ollama/llama3.2:3b-instruct-fp16-True-3-10] PASSED tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=ollama/llama3.2:3b-instruct-fp16-True-4-10] PASSED tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=ollama/llama3.2:3b-instruct-fp16-True-5-10] PASSED tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=ollama/llama3.2:3b-instruct-fp16-True-6-10] PASSED tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=ollama/llama3.2:3b-instruct-fp16-True-7-10] PASSED tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=ollama/llama3.2:3b-instruct-fp16-True-8-10] PASSED tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=ollama/llama3.2:3b-instruct-fp16-True-9-10] PASSED tests/integration/inference/test_openai_completion.py::test_inference_store_tool_calls[client_with_models-txt=ollama/llama3.2:3b-instruct-fp16-True-10-10] PASSEDTerminating llama stack server process... Terminating process 53307 and its group... Server process and children terminated gracefully --- .../inference/test_openai_completion.py | 25 +++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 72137662d..62185e470 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -5,6 +5,8 @@ # the root directory of this source tree. +import time + import pytest from ..test_cases.test_case import TestCase @@ -323,8 +325,15 @@ def test_inference_store(compat_client, client_with_models, text_model_id, strea response_id = response.id content = response.choices[0].message.content - responses = client.chat.completions.list(limit=1000) - assert response_id in [r.id for r in responses.data] + tries = 0 + while tries < 10: + responses = client.chat.completions.list(limit=1000) + if response_id in [r.id for r in responses.data]: + break + else: + tries += 1 + time.sleep(0.1) + assert tries < 10, f"Response {response_id} not found after 1 second" retrieved_response = client.chat.completions.retrieve(response_id) assert retrieved_response.id == response_id @@ -388,6 +397,18 @@ def test_inference_store_tool_calls(compat_client, client_with_models, text_mode response_id = response.id content = response.choices[0].message.content + # wait for the response to be stored + tries = 0 + while tries < 10: + responses = client.chat.completions.list(limit=1000) + if response_id in [r.id for r in responses.data]: + break + else: + tries += 1 + time.sleep(0.1) + + assert tries < 10, f"Response {response_id} not found after 1 second" + responses = client.chat.completions.list(limit=1000) assert response_id in [r.id for r in responses.data] From 0b00c68d59d28ea4a502ccb54708bd384c4dea73 Mon Sep 17 00:00:00 2001 From: Sumanth Kamenani Date: Fri, 5 Sep 2025 04:45:11 -0400 Subject: [PATCH 08/17] fix: use lambda pattern for bedrock config env vars (#3307) # What does this PR do? Improved bedrock provider config to read from environment variables like AWS_ACCESS_KEY_ID. Updated all fields to use default_factory with lambda patterns like the nvidia provider does. Now the environment variables work as documented. Closes #3305 ## Test Plan Ran the new bedrock config tests: ```bash python -m pytest tests/unit/providers/inference/bedrock/test_config.py -v Verified existing provider tests still work: python -m pytest tests/unit/providers/test_configs.py -v --- .../providers/inference/remote_bedrock.md | 4 +- .../source/providers/safety/remote_bedrock.md | 4 +- llama_stack/providers/utils/bedrock/config.py | 22 ++++--- .../inference/bedrock/test_config.py | 63 +++++++++++++++++++ 4 files changed, 79 insertions(+), 14 deletions(-) create mode 100644 tests/unit/providers/inference/bedrock/test_config.py diff --git a/docs/source/providers/inference/remote_bedrock.md b/docs/source/providers/inference/remote_bedrock.md index 1454c54c2..216dd4adb 100644 --- a/docs/source/providers/inference/remote_bedrock.md +++ b/docs/source/providers/inference/remote_bedrock.md @@ -15,8 +15,8 @@ AWS Bedrock inference provider for accessing various AI models through AWS's man | `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE | | `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS | | `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE | -| `connect_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. | -| `read_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. | +| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. | +| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. | | `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). | ## Sample Configuration diff --git a/docs/source/providers/safety/remote_bedrock.md b/docs/source/providers/safety/remote_bedrock.md index 3c1d6bcb0..99d77dd72 100644 --- a/docs/source/providers/safety/remote_bedrock.md +++ b/docs/source/providers/safety/remote_bedrock.md @@ -15,8 +15,8 @@ AWS Bedrock safety provider for content moderation using AWS's safety services. | `profile_name` | `str \| None` | No | | The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE | | `total_max_attempts` | `int \| None` | No | | An integer representing the maximum number of attempts that will be made for a single request, including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS | | `retry_mode` | `str \| None` | No | | A string representing the type of retries Boto3 will perform.Default use environment variable: AWS_RETRY_MODE | -| `connect_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. | -| `read_timeout` | `float \| None` | No | 60 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. | +| `connect_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to make a connection. The default is 60 seconds. | +| `read_timeout` | `float \| None` | No | 60.0 | The time in seconds till a timeout exception is thrown when attempting to read from a connection.The default is 60 seconds. | | `session_ttl` | `int \| None` | No | 3600 | The time in seconds till a session expires. The default is 3600 seconds (1 hour). | ## Sample Configuration diff --git a/llama_stack/providers/utils/bedrock/config.py b/llama_stack/providers/utils/bedrock/config.py index b25617d76..2745c88cb 100644 --- a/llama_stack/providers/utils/bedrock/config.py +++ b/llama_stack/providers/utils/bedrock/config.py @@ -4,53 +4,55 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import os + from pydantic import BaseModel, Field class BedrockBaseConfig(BaseModel): aws_access_key_id: str | None = Field( - default=None, + default_factory=lambda: os.getenv("AWS_ACCESS_KEY_ID"), description="The AWS access key to use. Default use environment variable: AWS_ACCESS_KEY_ID", ) aws_secret_access_key: str | None = Field( - default=None, + default_factory=lambda: os.getenv("AWS_SECRET_ACCESS_KEY"), description="The AWS secret access key to use. Default use environment variable: AWS_SECRET_ACCESS_KEY", ) aws_session_token: str | None = Field( - default=None, + default_factory=lambda: os.getenv("AWS_SESSION_TOKEN"), description="The AWS session token to use. Default use environment variable: AWS_SESSION_TOKEN", ) region_name: str | None = Field( - default=None, + default_factory=lambda: os.getenv("AWS_DEFAULT_REGION"), description="The default AWS Region to use, for example, us-west-1 or us-west-2." "Default use environment variable: AWS_DEFAULT_REGION", ) profile_name: str | None = Field( - default=None, + default_factory=lambda: os.getenv("AWS_PROFILE"), description="The profile name that contains credentials to use.Default use environment variable: AWS_PROFILE", ) total_max_attempts: int | None = Field( - default=None, + default_factory=lambda: int(val) if (val := os.getenv("AWS_MAX_ATTEMPTS")) else None, description="An integer representing the maximum number of attempts that will be made for a single request, " "including the initial attempt. Default use environment variable: AWS_MAX_ATTEMPTS", ) retry_mode: str | None = Field( - default=None, + default_factory=lambda: os.getenv("AWS_RETRY_MODE"), description="A string representing the type of retries Boto3 will perform." "Default use environment variable: AWS_RETRY_MODE", ) connect_timeout: float | None = Field( - default=60, + default_factory=lambda: float(os.getenv("AWS_CONNECT_TIMEOUT", "60")), description="The time in seconds till a timeout exception is thrown when attempting to make a connection. " "The default is 60 seconds.", ) read_timeout: float | None = Field( - default=60, + default_factory=lambda: float(os.getenv("AWS_READ_TIMEOUT", "60")), description="The time in seconds till a timeout exception is thrown when attempting to read from a connection." "The default is 60 seconds.", ) session_ttl: int | None = Field( - default=3600, + default_factory=lambda: int(os.getenv("AWS_SESSION_TTL", "3600")), description="The time in seconds till a session expires. The default is 3600 seconds (1 hour).", ) diff --git a/tests/unit/providers/inference/bedrock/test_config.py b/tests/unit/providers/inference/bedrock/test_config.py new file mode 100644 index 000000000..1b8639f2e --- /dev/null +++ b/tests/unit/providers/inference/bedrock/test_config.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import os +from unittest.mock import patch + +from llama_stack.providers.utils.bedrock.config import BedrockBaseConfig + + +class TestBedrockBaseConfig: + def test_defaults_work_without_env_vars(self): + with patch.dict(os.environ, {}, clear=True): + config = BedrockBaseConfig() + + # Basic creds should be None + assert config.aws_access_key_id is None + assert config.aws_secret_access_key is None + assert config.region_name is None + + # Timeouts get defaults + assert config.connect_timeout == 60.0 + assert config.read_timeout == 60.0 + assert config.session_ttl == 3600 + + def test_env_vars_get_picked_up(self): + env_vars = { + "AWS_ACCESS_KEY_ID": "AKIATEST123", + "AWS_SECRET_ACCESS_KEY": "secret123", + "AWS_DEFAULT_REGION": "us-west-2", + "AWS_MAX_ATTEMPTS": "5", + "AWS_RETRY_MODE": "adaptive", + "AWS_CONNECT_TIMEOUT": "30", + } + + with patch.dict(os.environ, env_vars, clear=True): + config = BedrockBaseConfig() + + assert config.aws_access_key_id == "AKIATEST123" + assert config.aws_secret_access_key == "secret123" + assert config.region_name == "us-west-2" + assert config.total_max_attempts == 5 + assert config.retry_mode == "adaptive" + assert config.connect_timeout == 30.0 + + def test_partial_env_setup(self): + # Just setting one timeout var + with patch.dict(os.environ, {"AWS_CONNECT_TIMEOUT": "120"}, clear=True): + config = BedrockBaseConfig() + + assert config.connect_timeout == 120.0 + assert config.read_timeout == 60.0 # still default + assert config.aws_access_key_id is None + + def test_bad_max_attempts_breaks(self): + with patch.dict(os.environ, {"AWS_MAX_ATTEMPTS": "not_a_number"}, clear=True): + try: + BedrockBaseConfig() + raise AssertionError("Should have failed on bad int conversion") + except ValueError: + pass # expected From 64b297716247317a0e0714b61fc7ee77bbfb9e52 Mon Sep 17 00:00:00 2001 From: Derek Higgins Date: Fri, 5 Sep 2025 13:09:36 +0100 Subject: [PATCH 09/17] fix: Fix locations of distrubution runtime directories (#3336) The defaults were mixed up Signed-off-by: Derek Higgins --- llama_stack/distributions/ci-tests/ci_tests.py | 4 +--- llama_stack/distributions/ci-tests/run.yaml | 18 +++++++++--------- llama_stack/distributions/starter-gpu/run.yaml | 18 +++++++++--------- .../distributions/starter-gpu/starter_gpu.py | 4 +--- llama_stack/distributions/starter/starter.py | 3 +-- 5 files changed, 21 insertions(+), 26 deletions(-) diff --git a/llama_stack/distributions/ci-tests/ci_tests.py b/llama_stack/distributions/ci-tests/ci_tests.py index 8fb61faca..ab102f5f3 100644 --- a/llama_stack/distributions/ci-tests/ci_tests.py +++ b/llama_stack/distributions/ci-tests/ci_tests.py @@ -11,9 +11,7 @@ from ..starter.starter import get_distribution_template as get_starter_distribut def get_distribution_template() -> DistributionTemplate: - template = get_starter_distribution_template() - name = "ci-tests" - template.name = name + template = get_starter_distribution_template(name="ci-tests") template.description = "CI tests for Llama Stack" return template diff --git a/llama_stack/distributions/ci-tests/run.yaml b/llama_stack/distributions/ci-tests/run.yaml index 7523df581..26a677c7a 100644 --- a/llama_stack/distributions/ci-tests/run.yaml +++ b/llama_stack/distributions/ci-tests/run.yaml @@ -89,28 +89,28 @@ providers: config: kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/faiss_store.db - provider_id: sqlite-vec provider_type: inline::sqlite-vec config: - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec.db kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/sqlite_vec_registry.db - provider_id: ${env.MILVUS_URL:+milvus} provider_type: inline::milvus config: - db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db + db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/ci-tests}/milvus.db kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/milvus_registry.db - provider_id: ${env.CHROMADB_URL:+chromadb} provider_type: remote::chromadb config: url: ${env.CHROMADB_URL:=} kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests/}/chroma_remote_registry.db - provider_id: ${env.PGVECTOR_DB:+pgvector} provider_type: remote::pgvector config: @@ -121,15 +121,15 @@ providers: password: ${env.PGVECTOR_PASSWORD:=} kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/pgvector_registry.db files: - provider_id: meta-reference-files provider_type: inline::localfs config: - storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files} + storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/ci-tests/files} metadata_store: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/files_metadata.db safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/llama_stack/distributions/starter-gpu/run.yaml b/llama_stack/distributions/starter-gpu/run.yaml index 8aed61519..5d9dfcb27 100644 --- a/llama_stack/distributions/starter-gpu/run.yaml +++ b/llama_stack/distributions/starter-gpu/run.yaml @@ -89,28 +89,28 @@ providers: config: kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/faiss_store.db - provider_id: sqlite-vec provider_type: inline::sqlite-vec config: - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/sqlite_vec.db kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/sqlite_vec_registry.db - provider_id: ${env.MILVUS_URL:+milvus} provider_type: inline::milvus config: - db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db + db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter-gpu}/milvus.db kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/milvus_registry.db - provider_id: ${env.CHROMADB_URL:+chromadb} provider_type: remote::chromadb config: url: ${env.CHROMADB_URL:=} kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter/}/chroma_remote_registry.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu/}/chroma_remote_registry.db - provider_id: ${env.PGVECTOR_DB:+pgvector} provider_type: remote::pgvector config: @@ -121,15 +121,15 @@ providers: password: ${env.PGVECTOR_PASSWORD:=} kvstore: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/pgvector_registry.db files: - provider_id: meta-reference-files provider_type: inline::localfs config: - storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files} + storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter-gpu/files} metadata_store: type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter-gpu}/files_metadata.db safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/llama_stack/distributions/starter-gpu/starter_gpu.py b/llama_stack/distributions/starter-gpu/starter_gpu.py index 245334749..e7efcb283 100644 --- a/llama_stack/distributions/starter-gpu/starter_gpu.py +++ b/llama_stack/distributions/starter-gpu/starter_gpu.py @@ -11,9 +11,7 @@ from ..starter.starter import get_distribution_template as get_starter_distribut def get_distribution_template() -> DistributionTemplate: - template = get_starter_distribution_template() - name = "starter-gpu" - template.name = name + template = get_starter_distribution_template(name="starter-gpu") template.description = "Quick start template for running Llama Stack with several popular providers. This distribution is intended for GPU-enabled environments." template.providers["post_training"] = [ diff --git a/llama_stack/distributions/starter/starter.py b/llama_stack/distributions/starter/starter.py index a4bbc6371..2fca52700 100644 --- a/llama_stack/distributions/starter/starter.py +++ b/llama_stack/distributions/starter/starter.py @@ -99,9 +99,8 @@ def get_remote_inference_providers() -> list[Provider]: return inference_providers -def get_distribution_template() -> DistributionTemplate: +def get_distribution_template(name: str = "starter") -> DistributionTemplate: remote_inference_providers = get_remote_inference_providers() - name = "starter" providers = { "inference": [BuildProvider(provider_type=p.provider_type, module=p.module) for p in remote_inference_providers] From e2fe39aee108c1796ddc1be1e44acd25c800082e Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Fri, 5 Sep 2025 07:40:34 -0600 Subject: [PATCH 10/17] feat!: Migrate Vector DB IDs to Vector Store IDs (breaking change) (#3253) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? This change migrates the VectorDB id generation to Vector Stores. This is a breaking change for **_some users_** that may have application code using the `vector_db_id` parameter in the request of the VectorDB protocol instead of the `VectorDB.identifier` in the response. By default we will now create a Vector Store every time we register a VectorDB. The caveat with this approach is that this maps the `vector_db_id` → `vector_store.name`. This is a reasonable tradeoff to transition users towards OpenAI Vector Stores. As an added benefit, registering VectorDBs will result in them appearing in the VectorStores admin UI. ### Why? This PR makes the `POST` API call to `/v1/vector-dbs` swap the `vector_db_id` parameter in the **request body** into the VectorStore's name field and sets the `vector_db_id` to the generated vector store id (e.g., `vs_038247dd-4bbb-4dbb-a6be-d5ecfd46cfdb`). That means that users would have to do something like follows in their application code: ```python res = client.vector_dbs.register( vector_db_id='my-vector-db-id', embedding_model='ollama/all-minilm:l6-v2', embedding_dimension=384, ) vector_db_id = res.identifier ``` And then the rest of their code would behave, including `VectorIO`'s insert protocol using `vector_db_id` in the request. An alternative implementation would be to just delete the `vector_db_id` parameter in `VectorDB` but the end result would still require users having to write `vector_db_id = res.identifier` since `VectorStores.create()` generates the ID for you. So this approach felt the easiest way to migrate users towards VectorStores (subsequent PRs will be added to trigger `files.create()` and `vector_stores.files.create()`). ## Test Plan Unit tests and integration tests have been added. Signed-off-by: Francisco Javier Arceo --- llama_stack/core/routing_tables/vector_dbs.py | 26 +++- tests/integration/vector_io/test_vector_io.py | 75 +++++++---- .../routers/test_routing_tables.py | 30 ++++- .../routing_tables/test_vector_dbs.py | 127 ++++++++++++++++-- 4 files changed, 209 insertions(+), 49 deletions(-) diff --git a/llama_stack/core/routing_tables/vector_dbs.py b/llama_stack/core/routing_tables/vector_dbs.py index 00f71b4fe..497894064 100644 --- a/llama_stack/core/routing_tables/vector_dbs.py +++ b/llama_stack/core/routing_tables/vector_dbs.py @@ -52,7 +52,6 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): provider_vector_db_id: str | None = None, vector_db_name: str | None = None, ) -> VectorDB: - provider_vector_db_id = provider_vector_db_id or vector_db_id if provider_id is None: if len(self.impls_by_provider_id) > 0: provider_id = list(self.impls_by_provider_id.keys())[0] @@ -69,14 +68,33 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): raise ModelTypeError(embedding_model, model.model_type, ModelType.embedding) if "embedding_dimension" not in model.metadata: raise ValueError(f"Model {embedding_model} does not have an embedding dimension") + + provider = self.impls_by_provider_id[provider_id] + logger.warning( + "VectorDB is being deprecated in future releases in favor of VectorStore. Please migrate your usage accordingly." + ) + vector_store = await provider.openai_create_vector_store( + name=vector_db_name or vector_db_id, + embedding_model=embedding_model, + embedding_dimension=model.metadata["embedding_dimension"], + provider_id=provider_id, + provider_vector_db_id=provider_vector_db_id, + ) + + vector_store_id = vector_store.id + actual_provider_vector_db_id = provider_vector_db_id or vector_store_id + logger.warning( + f"Ignoring vector_db_id {vector_db_id} and using vector_store_id {vector_store_id} instead. Setting VectorDB {vector_db_id} to VectorDB.vector_db_name" + ) + vector_db_data = { - "identifier": vector_db_id, + "identifier": vector_store_id, "type": ResourceType.vector_db.value, "provider_id": provider_id, - "provider_resource_id": provider_vector_db_id, + "provider_resource_id": actual_provider_vector_db_id, "embedding_model": embedding_model, "embedding_dimension": model.metadata["embedding_dimension"], - "vector_db_name": vector_db_name, + "vector_db_name": vector_store.name, } vector_db = TypeAdapter(VectorDBWithOwner).validate_python(vector_db_data) await self.register_object(vector_db) diff --git a/tests/integration/vector_io/test_vector_io.py b/tests/integration/vector_io/test_vector_io.py index 07faa0db1..979eff6bb 100644 --- a/tests/integration/vector_io/test_vector_io.py +++ b/tests/integration/vector_io/test_vector_io.py @@ -47,34 +47,45 @@ def client_with_empty_registry(client_with_models): def test_vector_db_retrieve(client_with_empty_registry, embedding_model_id, embedding_dimension): - # Register a memory bank first - vector_db_id = "test_vector_db" - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, + vector_db_name = "test_vector_db" + register_response = client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_name, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, ) + actual_vector_db_id = register_response.identifier + # Retrieve the memory bank and validate its properties - response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=vector_db_id) + response = client_with_empty_registry.vector_dbs.retrieve(vector_db_id=actual_vector_db_id) assert response is not None - assert response.identifier == vector_db_id + assert response.identifier == actual_vector_db_id assert response.embedding_model == embedding_model_id - assert response.provider_resource_id == vector_db_id + assert response.identifier.startswith("vs_") def test_vector_db_register(client_with_empty_registry, embedding_model_id, embedding_dimension): - vector_db_id = "test_vector_db" - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, + vector_db_name = "test_vector_db" + response = client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_name, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, ) - vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] - assert vector_dbs_after_register == [vector_db_id] + actual_vector_db_id = response.identifier + assert actual_vector_db_id.startswith("vs_") + assert actual_vector_db_id != vector_db_name - client_with_empty_registry.vector_dbs.unregister(vector_db_id=vector_db_id) + vector_dbs_after_register = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] + assert vector_dbs_after_register == [actual_vector_db_id] + + vector_stores = client_with_empty_registry.vector_stores.list() + assert len(vector_stores.data) == 1 + vector_store = vector_stores.data[0] + assert vector_store.id == actual_vector_db_id + assert vector_store.name == vector_db_name + + client_with_empty_registry.vector_dbs.unregister(vector_db_id=actual_vector_db_id) vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] assert len(vector_dbs) == 0 @@ -91,20 +102,22 @@ def test_vector_db_register(client_with_empty_registry, embedding_model_id, embe ], ) def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding_dimension, sample_chunks, test_case): - vector_db_id = "test_vector_db" - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, + vector_db_name = "test_vector_db" + register_response = client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_name, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, ) + actual_vector_db_id = register_response.identifier + client_with_empty_registry.vector_io.insert( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, chunks=sample_chunks, ) response = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query="What is the capital of France?", ) assert response is not None @@ -113,7 +126,7 @@ def test_insert_chunks(client_with_empty_registry, embedding_model_id, embedding query, expected_doc_id = test_case response = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query=query, ) assert response is not None @@ -128,13 +141,15 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e "remote::qdrant": {"score_threshold": -1.0}, "inline::qdrant": {"score_threshold": -1.0}, } - vector_db_id = "test_precomputed_embeddings_db" - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, + vector_db_name = "test_precomputed_embeddings_db" + register_response = client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_name, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, ) + actual_vector_db_id = register_response.identifier + chunks_with_embeddings = [ Chunk( content="This is a test chunk with precomputed embedding.", @@ -144,13 +159,13 @@ def test_insert_chunks_with_precomputed_embeddings(client_with_empty_registry, e ] client_with_empty_registry.vector_io.insert( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, chunks=chunks_with_embeddings, ) provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0] response = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query="precomputed embedding test", params=vector_io_provider_params_dict.get(provider, None), ) @@ -173,13 +188,15 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb( "remote::qdrant": {"score_threshold": 0.0}, "inline::qdrant": {"score_threshold": 0.0}, } - vector_db_id = "test_precomputed_embeddings_db" - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, + vector_db_name = "test_precomputed_embeddings_db" + register_response = client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_name, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, ) + actual_vector_db_id = register_response.identifier + chunks_with_embeddings = [ Chunk( content="duplicate", @@ -189,13 +206,13 @@ def test_query_returns_valid_object_when_identical_to_embedding_in_vdb( ] client_with_empty_registry.vector_io.insert( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, chunks=chunks_with_embeddings, ) provider = [p.provider_id for p in client_with_empty_registry.providers.list() if p.api == "vector_io"][0] response = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query="duplicate", params=vector_io_provider_params_dict.get(provider, None), ) diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 2652f5c8d..1ceee81c6 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -146,6 +146,20 @@ class VectorDBImpl(Impl): async def unregister_vector_db(self, vector_db_id: str): return vector_db_id + async def openai_create_vector_store(self, **kwargs): + import time + import uuid + + from llama_stack.apis.vector_io.vector_io import VectorStoreFileCounts, VectorStoreObject + + vector_store_id = kwargs.get("provider_vector_db_id") or f"vs_{uuid.uuid4()}" + return VectorStoreObject( + id=vector_store_id, + name=kwargs.get("name", vector_store_id), + created_at=int(time.time()), + file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0), + ) + async def test_models_routing_table(cached_disk_dist_registry): table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) @@ -247,17 +261,21 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry): ) # Register multiple vector databases and verify listing - await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model") - await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model") + vdb1 = await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test_provider/test-model") + vdb2 = await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test_provider/test-model") vector_dbs = await table.list_vector_dbs() assert len(vector_dbs.data) == 2 vector_db_ids = {v.identifier for v in vector_dbs.data} - assert "test-vectordb" in vector_db_ids - assert "test-vectordb-2" in vector_db_ids + assert vdb1.identifier in vector_db_ids + assert vdb2.identifier in vector_db_ids - await table.unregister_vector_db(vector_db_id="test-vectordb") - await table.unregister_vector_db(vector_db_id="test-vectordb-2") + # Verify they have UUID-based identifiers + assert vdb1.identifier.startswith("vs_") + assert vdb2.identifier.startswith("vs_") + + await table.unregister_vector_db(vector_db_id=vdb1.identifier) + await table.unregister_vector_db(vector_db_id=vdb2.identifier) vector_dbs = await table.list_vector_dbs() assert len(vector_dbs.data) == 0 diff --git a/tests/unit/distribution/routing_tables/test_vector_dbs.py b/tests/unit/distribution/routing_tables/test_vector_dbs.py index 789eda433..3444f64c2 100644 --- a/tests/unit/distribution/routing_tables/test_vector_dbs.py +++ b/tests/unit/distribution/routing_tables/test_vector_dbs.py @@ -7,6 +7,7 @@ # Unit tests for the routing tables vector_dbs import time +import uuid from unittest.mock import AsyncMock import pytest @@ -34,6 +35,7 @@ from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceI class VectorDBImpl(Impl): def __init__(self): super().__init__(Api.vector_io) + self.vector_stores = {} async def register_vector_db(self, vector_db: VectorDB): return vector_db @@ -114,8 +116,35 @@ class VectorDBImpl(Impl): async def openai_delete_vector_store_file(self, vector_store_id, file_id): return VectorStoreFileDeleteResponse(id=file_id, deleted=True) + async def openai_create_vector_store( + self, + name=None, + embedding_model=None, + embedding_dimension=None, + provider_id=None, + provider_vector_db_id=None, + **kwargs, + ): + vector_store_id = provider_vector_db_id or f"vs_{uuid.uuid4()}" + vector_store = VectorStoreObject( + id=vector_store_id, + name=name or vector_store_id, + created_at=int(time.time()), + file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0), + ) + self.vector_stores[vector_store_id] = vector_store + return vector_store + + async def openai_list_vector_stores(self, **kwargs): + from llama_stack.apis.vector_io.vector_io import VectorStoreListResponse + + return VectorStoreListResponse( + data=list(self.vector_stores.values()), has_more=False, first_id=None, last_id=None + ) + async def test_vectordbs_routing_table(cached_disk_dist_registry): + n = 10 table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) await table.initialize() @@ -129,22 +158,98 @@ async def test_vectordbs_routing_table(cached_disk_dist_registry): ) # Register multiple vector databases and verify listing - await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model") - await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model") + vdb_dict = {} + for i in range(n): + vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model") + vector_dbs = await table.list_vector_dbs() - assert len(vector_dbs.data) == 2 + assert len(vector_dbs.data) == len(vdb_dict) vector_db_ids = {v.identifier for v in vector_dbs.data} - assert "test-vectordb" in vector_db_ids - assert "test-vectordb-2" in vector_db_ids - - await table.unregister_vector_db(vector_db_id="test-vectordb") - await table.unregister_vector_db(vector_db_id="test-vectordb-2") + for k in vdb_dict: + assert vdb_dict[k].identifier in vector_db_ids + for k in vdb_dict: + await table.unregister_vector_db(vector_db_id=vdb_dict[k].identifier) vector_dbs = await table.list_vector_dbs() assert len(vector_dbs.data) == 0 +async def test_vector_db_and_vector_store_id_mapping(cached_disk_dist_registry): + n = 10 + impl = VectorDBImpl() + table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {}) + await table.initialize() + + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + vdb_dict = {} + for i in range(n): + vdb_dict[i] = await table.register_vector_db(vector_db_id=f"test-vectordb-{i}", embedding_model="test-model") + + vector_dbs = await table.list_vector_dbs() + vector_db_ids = {v.identifier for v in vector_dbs.data} + + vector_stores = await impl.openai_list_vector_stores() + vector_store_ids = {v.id for v in vector_stores.data} + + assert vector_db_ids == vector_store_ids, ( + f"Vector DB IDs {vector_db_ids} don't match vector store IDs {vector_store_ids}" + ) + + for vector_store in vector_stores.data: + vector_db = await table.get_vector_db(vector_store.id) + assert vector_store.name == vector_db.vector_db_name, ( + f"Vector store name {vector_store.name} doesn't match vector store ID {vector_store.id}" + ) + + for vector_db_id in vector_db_ids: + await table.unregister_vector_db(vector_db_id) + + assert len((await table.list_vector_dbs()).data) == 0 + + +async def test_vector_db_id_becomes_vector_store_name(cached_disk_dist_registry): + impl = VectorDBImpl() + table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, {}) + await table.initialize() + + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + user_provided_id = "my-custom-vector-db" + await table.register_vector_db(vector_db_id=user_provided_id, embedding_model="test-model") + + vector_stores = await impl.openai_list_vector_stores() + assert len(vector_stores.data) == 1 + + vector_store = vector_stores.data[0] + + assert vector_store.name == user_provided_id + + assert vector_store.id.startswith("vs_") + assert vector_store.id != user_provided_id + + vector_dbs = await table.list_vector_dbs() + assert len(vector_dbs.data) == 1 + assert vector_dbs.data[0].identifier == vector_store.id + + await table.unregister_vector_db(vector_store.id) + + async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry): impl = VectorDBImpl() impl.openai_retrieve_vector_store = AsyncMock(return_value="OK") @@ -164,7 +269,8 @@ async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registr authorized_user = User(principal="alice", attributes={"roles": [authorized_team]}) with request_provider_data_context({}, authorized_user): - _ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model") + registered_vdb = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model") + authorized_table = registered_vdb.identifier # Use the actual generated ID # Authorized reader with request_provider_data_context({}, authorized_user): @@ -227,7 +333,8 @@ async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_regis ) with request_provider_data_context({}, admin_user): - await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model") + registered_vdb = await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model") + vector_db_id = registered_vdb.identifier # Use the actual generated ID read_methods = [ (table.openai_retrieve_vector_store, (vector_db_id,), {}), From df1526991f6d5cfca05d3c4f1077b67f4832d93e Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Fri, 5 Sep 2025 14:59:57 -0400 Subject: [PATCH 11/17] feat(batches, completions): add /v1/completions support to /v1/batches (#3309) # What does this PR do? add support for /v1/completions to the /v1/batches api ## Test Plan ci --- .../inline/batches/reference/batches.py | 69 +++++++++++++------ tests/integration/batches/test_batches.py | 55 +++++++++++++++ .../recordings/responses/41e27b9b5d09.json | 42 +++++++++++ .../unit/providers/batches/test_reference.py | 65 +++++++++++++++-- 4 files changed, 205 insertions(+), 26 deletions(-) create mode 100644 tests/integration/recordings/responses/41e27b9b5d09.json diff --git a/llama_stack/providers/inline/batches/reference/batches.py b/llama_stack/providers/inline/batches/reference/batches.py index 26f0ad15a..e049518a4 100644 --- a/llama_stack/providers/inline/batches/reference/batches.py +++ b/llama_stack/providers/inline/batches/reference/batches.py @@ -178,9 +178,9 @@ class ReferenceBatchesImpl(Batches): # TODO: set expiration time for garbage collection - if endpoint not in ["/v1/chat/completions"]: + if endpoint not in ["/v1/chat/completions", "/v1/completions"]: raise ValueError( - f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions. Code: invalid_value. Param: endpoint", + f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions. Code: invalid_value. Param: endpoint", ) if completion_window != "24h": @@ -424,13 +424,21 @@ class ReferenceBatchesImpl(Batches): ) valid = False - for param, expected_type, type_string in [ - ("model", str, "a string"), - # messages is specific to /v1/chat/completions - # we could skip validating messages here and let inference fail. however, - # that would be a very expensive way to find out messages is wrong. - ("messages", list, "an array"), # TODO: allow messages to be a string? - ]: + if batch.endpoint == "/v1/chat/completions": + required_params = [ + ("model", str, "a string"), + # messages is specific to /v1/chat/completions + # we could skip validating messages here and let inference fail. however, + # that would be a very expensive way to find out messages is wrong. + ("messages", list, "an array"), # TODO: allow messages to be a string? + ] + else: # /v1/completions + required_params = [ + ("model", str, "a string"), + ("prompt", str, "a string"), # TODO: allow prompt to be a list of strings?? + ] + + for param, expected_type, type_string in required_params: if param not in body: errors.append( BatchError( @@ -591,20 +599,37 @@ class ReferenceBatchesImpl(Batches): try: # TODO(SECURITY): review body for security issues - request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]] - chat_response = await self.inference_api.openai_chat_completion(**request.body) + if request.url == "/v1/chat/completions": + request.body["messages"] = [convert_to_openai_message_param(msg) for msg in request.body["messages"]] + chat_response = await self.inference_api.openai_chat_completion(**request.body) - # this is for mypy, we don't allow streaming so we'll get the right type - assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method" - return { - "id": request_id, - "custom_id": request.custom_id, - "response": { - "status_code": 200, - "request_id": request_id, # TODO: should this be different? - "body": chat_response.model_dump_json(), - }, - } + # this is for mypy, we don't allow streaming so we'll get the right type + assert hasattr(chat_response, "model_dump_json"), "Chat response must have model_dump_json method" + return { + "id": request_id, + "custom_id": request.custom_id, + "response": { + "status_code": 200, + "request_id": request_id, # TODO: should this be different? + "body": chat_response.model_dump_json(), + }, + } + else: # /v1/completions + completion_response = await self.inference_api.openai_completion(**request.body) + + # this is for mypy, we don't allow streaming so we'll get the right type + assert hasattr(completion_response, "model_dump_json"), ( + "Completion response must have model_dump_json method" + ) + return { + "id": request_id, + "custom_id": request.custom_id, + "response": { + "status_code": 200, + "request_id": request_id, + "body": completion_response.model_dump_json(), + }, + } except Exception as e: logger.info(f"Error processing request {request.custom_id} in batch {batch_id}: {e}") return { diff --git a/tests/integration/batches/test_batches.py b/tests/integration/batches/test_batches.py index 59811b7a4..d55a68bd3 100644 --- a/tests/integration/batches/test_batches.py +++ b/tests/integration/batches/test_batches.py @@ -268,3 +268,58 @@ class TestBatchesIntegration: deleted_error_file = openai_client.files.delete(final_batch.error_file_id) assert deleted_error_file.deleted, f"Error file {final_batch.error_file_id} was not deleted successfully" + + def test_batch_e2e_completions(self, openai_client, batch_helper, text_model_id): + """Run an end-to-end batch with a single successful text completion request.""" + request_body = {"model": text_model_id, "prompt": "Say completions", "max_tokens": 20} + + batch_requests = [ + { + "custom_id": "success-1", + "method": "POST", + "url": "/v1/completions", + "body": request_body, + } + ] + + with batch_helper.create_file(batch_requests) as uploaded_file: + batch = openai_client.batches.create( + input_file_id=uploaded_file.id, + endpoint="/v1/completions", + completion_window="24h", + metadata={"test": "e2e_completions_success"}, + ) + + final_batch = batch_helper.wait_for( + batch.id, + max_wait_time=3 * 60, + expected_statuses={"completed"}, + timeout_action="skip", + ) + + assert final_batch.status == "completed" + assert final_batch.request_counts is not None + assert final_batch.request_counts.total == 1 + assert final_batch.request_counts.completed == 1 + assert final_batch.output_file_id is not None + + output_content = openai_client.files.content(final_batch.output_file_id) + if isinstance(output_content, str): + output_text = output_content + else: + output_text = output_content.content.decode("utf-8") + + output_lines = output_text.strip().split("\n") + assert len(output_lines) == 1 + + result = json.loads(output_lines[0]) + assert result["custom_id"] == "success-1" + assert "response" in result + assert result["response"]["status_code"] == 200 + + deleted_output_file = openai_client.files.delete(final_batch.output_file_id) + assert deleted_output_file.deleted + + if final_batch.error_file_id is not None: + deleted_error_file = openai_client.files.delete(final_batch.error_file_id) + assert deleted_error_file.deleted diff --git a/tests/integration/recordings/responses/41e27b9b5d09.json b/tests/integration/recordings/responses/41e27b9b5d09.json new file mode 100644 index 000000000..45d140843 --- /dev/null +++ b/tests/integration/recordings/responses/41e27b9b5d09.json @@ -0,0 +1,42 @@ +{ + "request": { + "method": "POST", + "url": "http://0.0.0.0:11434/v1/v1/completions", + "headers": {}, + "body": { + "model": "llama3.2:3b-instruct-fp16", + "prompt": "Say completions", + "max_tokens": 20 + }, + "endpoint": "/v1/completions", + "model": "llama3.2:3b-instruct-fp16" + }, + "response": { + "body": { + "__type__": "openai.types.completion.Completion", + "__data__": { + "id": "cmpl-271", + "choices": [ + { + "finish_reason": "length", + "index": 0, + "logprobs": null, + "text": "You want me to respond with a completion, but you didn't specify what I should complete. Could" + } + ], + "created": 1756846620, + "model": "llama3.2:3b-instruct-fp16", + "object": "text_completion", + "system_fingerprint": "fp_ollama", + "usage": { + "completion_tokens": 20, + "prompt_tokens": 28, + "total_tokens": 48, + "completion_tokens_details": null, + "prompt_tokens_details": null + } + } + }, + "is_streaming": false + } +} diff --git a/tests/unit/providers/batches/test_reference.py b/tests/unit/providers/batches/test_reference.py index 0ca866f7b..dfef5e040 100644 --- a/tests/unit/providers/batches/test_reference.py +++ b/tests/unit/providers/batches/test_reference.py @@ -46,7 +46,8 @@ The tests are categorized and outlined below, keep this updated: * test_validate_input_url_mismatch (negative) * test_validate_input_multiple_errors_per_request (negative) * test_validate_input_invalid_request_format (negative) - * test_validate_input_missing_parameters (parametrized negative - custom_id, method, url, body, model, messages missing validation) + * test_validate_input_missing_parameters_chat_completions (parametrized negative - custom_id, method, url, body, model, messages missing validation for chat/completions) + * test_validate_input_missing_parameters_completions (parametrized negative - custom_id, method, url, body, model, prompt missing validation for completions) * test_validate_input_invalid_parameter_types (parametrized negative - custom_id, url, method, body, model, messages type validation) The tests use temporary SQLite databases for isolation and mock external @@ -213,7 +214,6 @@ class TestReferenceBatchesImpl: "endpoint", [ "/v1/embeddings", - "/v1/completions", "/v1/invalid/endpoint", "", ], @@ -499,8 +499,10 @@ class TestReferenceBatchesImpl: ("messages", "body.messages", "invalid_request", "Messages parameter is required"), ], ) - async def test_validate_input_missing_parameters(self, provider, param_name, param_path, error_code, error_message): - """Test _validate_input when file contains request with missing required parameters.""" + async def test_validate_input_missing_parameters_chat_completions( + self, provider, param_name, param_path, error_code, error_message + ): + """Test _validate_input when file contains request with missing required parameters for chat completions.""" provider.files_api.openai_retrieve_file = AsyncMock() mock_response = MagicMock() @@ -541,6 +543,61 @@ class TestReferenceBatchesImpl: assert errors[0].message == error_message assert errors[0].param == param_path + @pytest.mark.parametrize( + "param_name,param_path,error_code,error_message", + [ + ("custom_id", "custom_id", "missing_required_parameter", "Missing required parameter: custom_id"), + ("method", "method", "missing_required_parameter", "Missing required parameter: method"), + ("url", "url", "missing_required_parameter", "Missing required parameter: url"), + ("body", "body", "missing_required_parameter", "Missing required parameter: body"), + ("model", "body.model", "invalid_request", "Model parameter is required"), + ("prompt", "body.prompt", "invalid_request", "Prompt parameter is required"), + ], + ) + async def test_validate_input_missing_parameters_completions( + self, provider, param_name, param_path, error_code, error_message + ): + """Test _validate_input when file contains request with missing required parameters for text completions.""" + provider.files_api.openai_retrieve_file = AsyncMock() + mock_response = MagicMock() + + base_request = { + "custom_id": "req-1", + "method": "POST", + "url": "/v1/completions", + "body": {"model": "test-model", "prompt": "Hello"}, + } + + # Remove the specific parameter being tested + if "." in param_path: + top_level, nested_param = param_path.split(".", 1) + del base_request[top_level][nested_param] + else: + del base_request[param_name] + + mock_response.body = json.dumps(base_request).encode() + provider.files_api.openai_retrieve_file_content = AsyncMock(return_value=mock_response) + + batch = BatchObject( + id="batch_test", + object="batch", + endpoint="/v1/completions", + input_file_id=f"missing_{param_name}_file", + completion_window="24h", + status="validating", + created_at=1234567890, + ) + + errors, requests = await provider._validate_input(batch) + + assert len(errors) == 1 + assert len(requests) == 0 + + assert errors[0].code == error_code + assert errors[0].line == 1 + assert errors[0].message == error_message + assert errors[0].param == param_path + async def test_validate_input_url_mismatch(self, provider): """Test _validate_input when file contains request with URL that doesn't match batch endpoint.""" provider.files_api.openai_retrieve_file = AsyncMock() From 0c2757a05b504bafef1bc589376712b9ac9a1c52 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Fri, 5 Sep 2025 15:00:09 -0400 Subject: [PATCH 12/17] chore(sambanova test): skip with_n tests for sambanova, it is not implemented server-side (#3342) # What does this PR do? skip a test that cannot pass for sambanova see https://docs-legacy.sambanova.ai/sambastudio/latest/open-ai-api.html\#_example_requests_using_openai_client ## Test Plan ci --- .../inference/test_openai_completion.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 62185e470..bb447b3c1 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -58,6 +58,15 @@ def skip_if_model_doesnt_support_suffix(client_with_models, model_id): pytest.skip(f"Provider {provider.provider_type} doesn't support suffix.") +def skip_if_doesnt_support_n(client_with_models, model_id): + provider = provider_from_model(client_with_models, model_id) + if provider.provider_type in ( + "remote::sambanova", + "remote::ollama", + ): + pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.") + + def skip_if_model_doesnt_support_openai_chat_completion(client_with_models, model_id): provider = provider_from_model(client_with_models, model_id) if provider.provider_type in ( @@ -262,10 +271,7 @@ def test_openai_chat_completion_streaming(compat_client, client_with_models, tex ) def test_openai_chat_completion_streaming_with_n(compat_client, client_with_models, text_model_id, test_case): skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) - - provider = provider_from_model(client_with_models, text_model_id) - if provider.provider_type == "remote::ollama": - pytest.skip(f"Model {text_model_id} hosted by {provider.provider_type} doesn't support n > 1.") + skip_if_doesnt_support_n(client_with_models, text_model_id) tc = TestCase(test_case) question = tc["question"] From 47b640370e275dd178c0bb9fcf822467e032120d Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 5 Sep 2025 13:58:49 -0700 Subject: [PATCH 13/17] feat(tests): introduce a test "suite" concept to encompass dirs, options (#3339) Our integration tests need to be 'grouped' because each group often needs a specific set of models it works with. We separated vision tests due to this, and we have a separate set of tests which test "Responses" API. This PR makes this system a bit more official so it is very easy to target these groups and apply all testing infrastructure towards all the groups (for example, record-replay) uniformly. There are three suites declared: - base - vision - responses Note that our CI currently runs the "base" and "vision" suites. You can use the `--suite` option when running pytest (or any of the testing scripts or workflows.) For example: ``` OLLAMA_URL=http://localhost:11434 \ pytest -s -v tests/integration/ --stack-config starter --suite vision ``` --- .../actions/run-and-record-tests/action.yml | 30 ++-- .github/actions/setup-ollama/action.yml | 8 +- .../actions/setup-test-environment/action.yml | 8 +- .github/workflows/README.md | 2 +- .github/workflows/integration-tests.yml | 20 +-- .../workflows/record-integration-tests.yml | 32 ++--- scripts/github/schedule-record-workflow.sh | 32 +++-- scripts/integration-tests.sh | 133 +++++++----------- tests/README.md | 2 +- tests/integration/README.md | 21 +++ tests/integration/conftest.py | 75 ++++++++-- .../{non_ci => }/responses/__init__.py | 0 .../responses/fixtures/__init__.py | 0 .../responses/fixtures/fixtures.py | 0 .../fixtures/images/vision_test_1.jpg | Bin .../fixtures/images/vision_test_2.jpg | Bin .../fixtures/images/vision_test_3.jpg | Bin .../fixtures/pdfs/llama_stack_and_models.pdf | Bin .../responses/fixtures/test_cases.py | 0 .../{non_ci => }/responses/helpers.py | 0 .../responses/streaming_assertions.py | 0 .../responses/test_basic_responses.py | 0 .../responses/test_file_search.py | 0 .../responses/test_tool_responses.py | 0 tests/integration/suites.py | 53 +++++++ 25 files changed, 255 insertions(+), 161 deletions(-) rename tests/integration/{non_ci => }/responses/__init__.py (100%) rename tests/integration/{non_ci => }/responses/fixtures/__init__.py (100%) rename tests/integration/{non_ci => }/responses/fixtures/fixtures.py (100%) rename tests/integration/{non_ci => }/responses/fixtures/images/vision_test_1.jpg (100%) rename tests/integration/{non_ci => }/responses/fixtures/images/vision_test_2.jpg (100%) rename tests/integration/{non_ci => }/responses/fixtures/images/vision_test_3.jpg (100%) rename tests/integration/{non_ci => }/responses/fixtures/pdfs/llama_stack_and_models.pdf (100%) rename tests/integration/{non_ci => }/responses/fixtures/test_cases.py (100%) rename tests/integration/{non_ci => }/responses/helpers.py (100%) rename tests/integration/{non_ci => }/responses/streaming_assertions.py (100%) rename tests/integration/{non_ci => }/responses/test_basic_responses.py (100%) rename tests/integration/{non_ci => }/responses/test_file_search.py (100%) rename tests/integration/{non_ci => }/responses/test_tool_responses.py (100%) create mode 100644 tests/integration/suites.py diff --git a/.github/actions/run-and-record-tests/action.yml b/.github/actions/run-and-record-tests/action.yml index 60550cfdc..7f028b104 100644 --- a/.github/actions/run-and-record-tests/action.yml +++ b/.github/actions/run-and-record-tests/action.yml @@ -2,13 +2,6 @@ name: 'Run and Record Tests' description: 'Run integration tests and handle recording/artifact upload' inputs: - test-subdirs: - description: 'Comma-separated list of test subdirectories to run' - required: true - test-pattern: - description: 'Regex pattern to pass to pytest -k' - required: false - default: '' stack-config: description: 'Stack configuration to use' required: true @@ -18,10 +11,18 @@ inputs: inference-mode: description: 'Inference mode (record or replay)' required: true - run-vision-tests: - description: 'Whether to run vision tests' + test-suite: + description: 'Test suite to use: base, responses, vision, etc.' required: false - default: 'false' + default: '' + test-subdirs: + description: 'Comma-separated list of test subdirectories to run; overrides test-suite' + required: false + default: '' + test-pattern: + description: 'Regex pattern to pass to pytest -k' + required: false + default: '' runs: using: 'composite' @@ -42,7 +43,7 @@ runs: --test-subdirs '${{ inputs.test-subdirs }}' \ --test-pattern '${{ inputs.test-pattern }}' \ --inference-mode '${{ inputs.inference-mode }}' \ - ${{ inputs.run-vision-tests == 'true' && '--run-vision-tests' || '' }} \ + --test-suite '${{ inputs.test-suite }}' \ | tee pytest-${{ inputs.inference-mode }}.log @@ -57,12 +58,7 @@ runs: echo "New recordings detected, committing and pushing" git add tests/integration/recordings/ - if [ "${{ inputs.run-vision-tests }}" == "true" ]; then - git commit -m "Recordings update from CI (vision)" - else - git commit -m "Recordings update from CI" - fi - + git commit -m "Recordings update from CI (test-suite: ${{ inputs.test-suite }})" git fetch origin ${{ github.ref_name }} git rebase origin/${{ github.ref_name }} echo "Rebased successfully" diff --git a/.github/actions/setup-ollama/action.yml b/.github/actions/setup-ollama/action.yml index e57876cb0..dc2f87e8c 100644 --- a/.github/actions/setup-ollama/action.yml +++ b/.github/actions/setup-ollama/action.yml @@ -1,17 +1,17 @@ name: Setup Ollama description: Start Ollama inputs: - run-vision-tests: - description: 'Run vision tests: "true" or "false"' + test-suite: + description: 'Test suite to use: base, responses, vision, etc.' required: false - default: 'false' + default: '' runs: using: "composite" steps: - name: Start Ollama shell: bash run: | - if [ "${{ inputs.run-vision-tests }}" == "true" ]; then + if [ "${{ inputs.test-suite }}" == "vision" ]; then image="ollama-with-vision-model" else image="ollama-with-models" diff --git a/.github/actions/setup-test-environment/action.yml b/.github/actions/setup-test-environment/action.yml index d830e3d13..3be76f009 100644 --- a/.github/actions/setup-test-environment/action.yml +++ b/.github/actions/setup-test-environment/action.yml @@ -12,10 +12,10 @@ inputs: description: 'Provider to setup (ollama or vllm)' required: true default: 'ollama' - run-vision-tests: - description: 'Whether to setup provider for vision tests' + test-suite: + description: 'Test suite to use: base, responses, vision, etc.' required: false - default: 'false' + default: '' inference-mode: description: 'Inference mode (record or replay)' required: true @@ -33,7 +33,7 @@ runs: if: ${{ inputs.provider == 'ollama' && inputs.inference-mode == 'record' }} uses: ./.github/actions/setup-ollama with: - run-vision-tests: ${{ inputs.run-vision-tests }} + test-suite: ${{ inputs.test-suite }} - name: Setup vllm if: ${{ inputs.provider == 'vllm' && inputs.inference-mode == 'record' }} diff --git a/.github/workflows/README.md b/.github/workflows/README.md index 8344d12a4..2e0df58b8 100644 --- a/.github/workflows/README.md +++ b/.github/workflows/README.md @@ -8,7 +8,7 @@ Llama Stack uses GitHub Actions for Continuous Integration (CI). Below is a tabl | Installer CI | [install-script-ci.yml](install-script-ci.yml) | Test the installation script | | Integration Auth Tests | [integration-auth-tests.yml](integration-auth-tests.yml) | Run the integration test suite with Kubernetes authentication | | SqlStore Integration Tests | [integration-sql-store-tests.yml](integration-sql-store-tests.yml) | Run the integration test suite with SqlStore | -| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suite from tests/integration in replay mode | +| Integration Tests (Replay) | [integration-tests.yml](integration-tests.yml) | Run the integration test suites from tests/integration in replay mode | | Vector IO Integration Tests | [integration-vector-io-tests.yml](integration-vector-io-tests.yml) | Run the integration test suite with various VectorIO providers | | Pre-commit | [pre-commit.yml](pre-commit.yml) | Run pre-commit checks | | Test Llama Stack Build | [providers-build.yml](providers-build.yml) | Test llama stack build | diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 57e582b20..bb53eea2f 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -1,6 +1,6 @@ name: Integration Tests (Replay) -run-name: Run the integration test suite from tests/integration in replay mode +run-name: Run the integration test suites from tests/integration in replay mode on: push: @@ -32,14 +32,6 @@ on: description: 'Test against a specific provider' type: string default: 'ollama' - test-subdirs: - description: 'Comma-separated list of test subdirectories to run' - type: string - default: '' - test-pattern: - description: 'Regex pattern to pass to pytest -k' - type: string - default: '' concurrency: # Skip concurrency for pushes to main - each commit should be tested independently @@ -50,7 +42,7 @@ jobs: run-replay-mode-tests: runs-on: ubuntu-latest - name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, vision={4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.run-vision-tests) }} + name: ${{ format('Integration Tests ({0}, {1}, {2}, client={3}, {4})', matrix.client-type, matrix.provider, matrix.python-version, matrix.client-version, matrix.test-suite) }} strategy: fail-fast: false @@ -61,7 +53,7 @@ jobs: # Use Python 3.13 only on nightly schedule (daily latest client test), otherwise use 3.12 python-version: ${{ github.event.schedule == '0 0 * * *' && fromJSON('["3.12", "3.13"]') || fromJSON('["3.12"]') }} client-version: ${{ (github.event.schedule == '0 0 * * *' || github.event.inputs.test-all-client-versions == 'true') && fromJSON('["published", "latest"]') || fromJSON('["latest"]') }} - run-vision-tests: [true, false] + test-suite: [base, vision] steps: - name: Checkout repository @@ -73,15 +65,13 @@ jobs: python-version: ${{ matrix.python-version }} client-version: ${{ matrix.client-version }} provider: ${{ matrix.provider }} - run-vision-tests: ${{ matrix.run-vision-tests }} + test-suite: ${{ matrix.test-suite }} inference-mode: 'replay' - name: Run tests uses: ./.github/actions/run-and-record-tests with: - test-subdirs: ${{ inputs.test-subdirs }} - test-pattern: ${{ inputs.test-pattern }} stack-config: ${{ matrix.client-type == 'library' && 'ci-tests' || 'server:ci-tests' }} provider: ${{ matrix.provider }} inference-mode: 'replay' - run-vision-tests: ${{ matrix.run-vision-tests }} + test-suite: ${{ matrix.test-suite }} diff --git a/.github/workflows/record-integration-tests.yml b/.github/workflows/record-integration-tests.yml index d4f5586e2..01797a54b 100644 --- a/.github/workflows/record-integration-tests.yml +++ b/.github/workflows/record-integration-tests.yml @@ -10,18 +10,18 @@ run-name: Run the integration test suite from tests/integration on: workflow_dispatch: inputs: - test-subdirs: - description: 'Comma-separated list of test subdirectories to run' - type: string - default: '' test-provider: description: 'Test against a specific provider' type: string default: 'ollama' - run-vision-tests: - description: 'Whether to run vision tests' - type: boolean - default: false + test-suite: + description: 'Test suite to use: base, responses, vision, etc.' + type: string + default: '' + test-subdirs: + description: 'Comma-separated list of test subdirectories to run; overrides test-suite' + type: string + default: '' test-pattern: description: 'Regex pattern to pass to pytest -k' type: string @@ -38,11 +38,11 @@ jobs: - name: Echo workflow inputs run: | echo "::group::Workflow Inputs" - echo "test-subdirs: ${{ inputs.test-subdirs }}" - echo "test-provider: ${{ inputs.test-provider }}" - echo "run-vision-tests: ${{ inputs.run-vision-tests }}" - echo "test-pattern: ${{ inputs.test-pattern }}" echo "branch: ${{ github.ref_name }}" + echo "test-provider: ${{ inputs.test-provider }}" + echo "test-suite: ${{ inputs.test-suite }}" + echo "test-subdirs: ${{ inputs.test-subdirs }}" + echo "test-pattern: ${{ inputs.test-pattern }}" echo "::endgroup::" - name: Checkout repository @@ -56,15 +56,15 @@ jobs: python-version: "3.12" # Use single Python version for recording client-version: "latest" provider: ${{ inputs.test-provider || 'ollama' }} - run-vision-tests: ${{ inputs.run-vision-tests }} + test-suite: ${{ inputs.test-suite }} inference-mode: 'record' - name: Run and record tests uses: ./.github/actions/run-and-record-tests with: - test-pattern: ${{ inputs.test-pattern }} - test-subdirs: ${{ inputs.test-subdirs }} stack-config: 'server:ci-tests' # recording must be done with server since more tests are run provider: ${{ inputs.test-provider || 'ollama' }} inference-mode: 'record' - run-vision-tests: ${{ inputs.run-vision-tests }} + test-suite: ${{ inputs.test-suite }} + test-subdirs: ${{ inputs.test-subdirs }} + test-pattern: ${{ inputs.test-pattern }} diff --git a/scripts/github/schedule-record-workflow.sh b/scripts/github/schedule-record-workflow.sh index e381b60b6..09e055611 100755 --- a/scripts/github/schedule-record-workflow.sh +++ b/scripts/github/schedule-record-workflow.sh @@ -15,7 +15,7 @@ set -euo pipefail BRANCH="" TEST_SUBDIRS="" TEST_PROVIDER="ollama" -RUN_VISION_TESTS=false +TEST_SUITE="base" TEST_PATTERN="" # Help function @@ -27,9 +27,9 @@ Trigger the integration test recording workflow remotely. This way you do not ne OPTIONS: -b, --branch BRANCH Branch to run the workflow on (defaults to current branch) - -s, --test-subdirs DIRS Comma-separated list of test subdirectories to run (REQUIRED) -p, --test-provider PROVIDER Test provider to use: vllm or ollama (default: ollama) - -v, --run-vision-tests Include vision tests in the recording + -t, --test-suite SUITE Test suite to use: base, responses, vision, etc. (default: base) + -s, --test-subdirs DIRS Comma-separated list of test subdirectories to run (overrides suite) -k, --test-pattern PATTERN Regex pattern to pass to pytest -k -h, --help Show this help message @@ -38,7 +38,7 @@ EXAMPLES: $0 --test-subdirs "agents" # Record tests for specific branch with vision tests - $0 -b my-feature-branch --test-subdirs "inference" --run-vision-tests + $0 -b my-feature-branch --test-suite vision # Record multiple test subdirectories with specific provider $0 --test-subdirs "agents,inference" --test-provider vllm @@ -71,9 +71,9 @@ while [[ $# -gt 0 ]]; do TEST_PROVIDER="$2" shift 2 ;; - -v|--run-vision-tests) - RUN_VISION_TESTS=true - shift + -t|--test-suite) + TEST_SUITE="$2" + shift 2 ;; -k|--test-pattern) TEST_PATTERN="$2" @@ -92,11 +92,11 @@ while [[ $# -gt 0 ]]; do done # Validate required parameters -if [[ -z "$TEST_SUBDIRS" ]]; then - echo "Error: --test-subdirs is required" - echo "Please specify which test subdirectories to run, e.g.:" +if [[ -z "$TEST_SUBDIRS" && -z "$TEST_SUITE" ]]; then + echo "Error: --test-subdirs or --test-suite is required" + echo "Please specify which test subdirectories to run or test suite to use, e.g.:" echo " $0 --test-subdirs \"agents,inference\"" - echo " $0 --test-subdirs \"inference\" --run-vision-tests" + echo " $0 --test-suite vision" echo "" exit 1 fi @@ -239,17 +239,19 @@ echo "Triggering integration test recording workflow..." echo "Branch: $BRANCH" echo "Test provider: $TEST_PROVIDER" echo "Test subdirs: $TEST_SUBDIRS" -echo "Run vision tests: $RUN_VISION_TESTS" +echo "Test suite: $TEST_SUITE" echo "Test pattern: ${TEST_PATTERN:-"(none)"}" echo "" # Prepare inputs for gh workflow run -INPUTS="-f test-subdirs='$TEST_SUBDIRS'" +if [[ -n "$TEST_SUBDIRS" ]]; then + INPUTS="-f test-subdirs='$TEST_SUBDIRS'" +fi if [[ -n "$TEST_PROVIDER" ]]; then INPUTS="$INPUTS -f test-provider='$TEST_PROVIDER'" fi -if [[ "$RUN_VISION_TESTS" == "true" ]]; then - INPUTS="$INPUTS -f run-vision-tests=true" +if [[ -n "$TEST_SUITE" ]]; then + INPUTS="$INPUTS -f test-suite='$TEST_SUITE'" fi if [[ -n "$TEST_PATTERN" ]]; then INPUTS="$INPUTS -f test-pattern='$TEST_PATTERN'" diff --git a/scripts/integration-tests.sh b/scripts/integration-tests.sh index 104ba5cf3..ab7e37579 100755 --- a/scripts/integration-tests.sh +++ b/scripts/integration-tests.sh @@ -16,7 +16,7 @@ STACK_CONFIG="" PROVIDER="" TEST_SUBDIRS="" TEST_PATTERN="" -RUN_VISION_TESTS="false" +TEST_SUITE="base" INFERENCE_MODE="replay" EXTRA_PARAMS="" @@ -28,12 +28,16 @@ Usage: $0 [OPTIONS] Options: --stack-config STRING Stack configuration to use (required) --provider STRING Provider to use (ollama, vllm, etc.) (required) - --test-subdirs STRING Comma-separated list of test subdirectories to run (default: 'inference') - --run-vision-tests Run vision tests instead of regular tests + --test-suite STRING Comma-separated list of test suites to run (default: 'base') --inference-mode STRING Inference mode: record or replay (default: replay) + --test-subdirs STRING Comma-separated list of test subdirectories to run (overrides suite) --test-pattern STRING Regex pattern to pass to pytest -k --help Show this help message +Suites are defined in tests/integration/suites.py. They are used to narrow the collection of tests and provide default model options. + +You can also specify subdirectories (of tests/integration) to select tests from, which will override the suite. + Examples: # Basic inference tests with ollama $0 --stack-config server:ci-tests --provider ollama @@ -42,7 +46,7 @@ Examples: $0 --stack-config server:ci-tests --provider vllm --test-subdirs 'inference,agents' # Vision tests with ollama - $0 --stack-config server:ci-tests --provider ollama --run-vision-tests + $0 --stack-config server:ci-tests --provider ollama --test-suite vision # Record mode for updating test recordings $0 --stack-config server:ci-tests --provider ollama --inference-mode record @@ -64,9 +68,9 @@ while [[ $# -gt 0 ]]; do TEST_SUBDIRS="$2" shift 2 ;; - --run-vision-tests) - RUN_VISION_TESTS="true" - shift + --test-suite) + TEST_SUITE="$2" + shift 2 ;; --inference-mode) INFERENCE_MODE="$2" @@ -92,22 +96,25 @@ done # Validate required parameters if [[ -z "$STACK_CONFIG" ]]; then echo "Error: --stack-config is required" - usage exit 1 fi if [[ -z "$PROVIDER" ]]; then echo "Error: --provider is required" - usage + exit 1 +fi + +if [[ -z "$TEST_SUITE" && -z "$TEST_SUBDIRS" ]]; then + echo "Error: --test-suite or --test-subdirs is required" exit 1 fi echo "=== Llama Stack Integration Test Runner ===" echo "Stack Config: $STACK_CONFIG" echo "Provider: $PROVIDER" -echo "Test Subdirs: $TEST_SUBDIRS" -echo "Vision Tests: $RUN_VISION_TESTS" echo "Inference Mode: $INFERENCE_MODE" +echo "Test Suite: $TEST_SUITE" +echo "Test Subdirs: $TEST_SUBDIRS" echo "Test Pattern: $TEST_PATTERN" echo "" @@ -194,84 +201,46 @@ if [[ -n "$TEST_PATTERN" ]]; then PYTEST_PATTERN="${PYTEST_PATTERN} and $TEST_PATTERN" fi -# Run vision tests if specified -if [[ "$RUN_VISION_TESTS" == "true" ]]; then - echo "Running vision tests..." - set +e - pytest -s -v tests/integration/inference/test_vision_inference.py \ - --stack-config="$STACK_CONFIG" \ - -k "$PYTEST_PATTERN" \ - --vision-model=ollama/llama3.2-vision:11b \ - --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ - --color=yes $EXTRA_PARAMS \ - --capture=tee-sys - exit_code=$? - set -e - - if [ $exit_code -eq 0 ]; then - echo "✅ Vision tests completed successfully" - elif [ $exit_code -eq 5 ]; then - echo "⚠️ No vision tests collected (pattern matched no tests)" - else - echo "❌ Vision tests failed" - exit 1 - fi - exit 0 -fi - -# Run regular tests -if [[ -z "$TEST_SUBDIRS" ]]; then - TEST_SUBDIRS=$(find tests/integration -maxdepth 1 -mindepth 1 -type d | - sed 's|tests/integration/||' | - grep -Ev "^(__pycache__|fixtures|test_cases|recordings|non_ci|post_training)$" | - sort) -fi echo "Test subdirs to run: $TEST_SUBDIRS" -# Collect all test files for the specified test types -TEST_FILES="" -for test_subdir in $(echo "$TEST_SUBDIRS" | tr ',' '\n'); do - # Skip certain test types for vllm provider - if [[ "$PROVIDER" == "vllm" ]]; then - if [[ "$test_subdir" == "safety" ]] || [[ "$test_subdir" == "post_training" ]] || [[ "$test_subdir" == "tool_runtime" ]]; then - echo "Skipping $test_subdir for vllm provider" - continue +if [[ -n "$TEST_SUBDIRS" ]]; then + # Collect all test files for the specified test types + TEST_FILES="" + for test_subdir in $(echo "$TEST_SUBDIRS" | tr ',' '\n'); do + if [[ -d "tests/integration/$test_subdir" ]]; then + # Find all Python test files in this directory + test_files=$(find tests/integration/$test_subdir -name "test_*.py" -o -name "*_test.py") + if [[ -n "$test_files" ]]; then + TEST_FILES="$TEST_FILES $test_files" + echo "Added test files from $test_subdir: $(echo $test_files | wc -w) files" + fi + else + echo "Warning: Directory tests/integration/$test_subdir does not exist" fi + done + + if [[ -z "$TEST_FILES" ]]; then + echo "No test files found for the specified test types" + exit 1 fi - if [[ "$STACK_CONFIG" != *"server:"* ]] && [[ "$test_subdir" == "batches" ]]; then - echo "Skipping $test_subdir for library client until types are supported" - continue - fi + echo "" + echo "=== Running all collected tests in a single pytest command ===" + echo "Total test files: $(echo $TEST_FILES | wc -w)" - if [[ -d "tests/integration/$test_subdir" ]]; then - # Find all Python test files in this directory - test_files=$(find tests/integration/$test_subdir -name "test_*.py" -o -name "*_test.py") - if [[ -n "$test_files" ]]; then - TEST_FILES="$TEST_FILES $test_files" - echo "Added test files from $test_subdir: $(echo $test_files | wc -w) files" - fi - else - echo "Warning: Directory tests/integration/$test_subdir does not exist" - fi -done - -if [[ -z "$TEST_FILES" ]]; then - echo "No test files found for the specified test types" - exit 1 + PYTEST_TARGET="$TEST_FILES" + EXTRA_PARAMS="$EXTRA_PARAMS --text-model=$TEXT_MODEL --embedding-model=sentence-transformers/all-MiniLM-L6-v2" +else + PYTEST_TARGET="tests/integration/" + EXTRA_PARAMS="$EXTRA_PARAMS --suite=$TEST_SUITE" fi -echo "" -echo "=== Running all collected tests in a single pytest command ===" -echo "Total test files: $(echo $TEST_FILES | wc -w)" - set +e -pytest -s -v $TEST_FILES \ +pytest -s -v $PYTEST_TARGET \ --stack-config="$STACK_CONFIG" \ -k "$PYTEST_PATTERN" \ - --text-model="$TEXT_MODEL" \ - --embedding-model=sentence-transformers/all-MiniLM-L6-v2 \ - --color=yes $EXTRA_PARAMS \ + $EXTRA_PARAMS \ + --color=yes \ --capture=tee-sys exit_code=$? set -e @@ -294,7 +263,13 @@ df -h # stop server if [[ "$STACK_CONFIG" == *"server:"* ]]; then echo "Stopping Llama Stack Server..." - kill $(lsof -i :8321 | awk 'NR>1 {print $2}') + pids=$(lsof -i :8321 | awk 'NR>1 {print $2}') + if [[ -n "$pids" ]]; then + echo "Killing Llama Stack Server processes: $pids" + kill -9 $pids + else + echo "No Llama Stack Server processes found ?!" + fi echo "Llama Stack Server stopped" fi diff --git a/tests/README.md b/tests/README.md index 81f025f86..c00829d3e 100644 --- a/tests/README.md +++ b/tests/README.md @@ -77,7 +77,7 @@ You must be careful when re-recording. CI workflows assume a specific setup for ./scripts/github/schedule-record-workflow.sh --test-subdirs "agents,inference" # Record with vision tests enabled -./scripts/github/schedule-record-workflow.sh --test-subdirs "inference" --run-vision-tests +./scripts/github/schedule-record-workflow.sh --test-suite vision # Record with specific provider ./scripts/github/schedule-record-workflow.sh --test-subdirs "agents" --test-provider vllm diff --git a/tests/integration/README.md b/tests/integration/README.md index d177cbebf..b05beeb98 100644 --- a/tests/integration/README.md +++ b/tests/integration/README.md @@ -42,6 +42,27 @@ Model parameters can be influenced by the following options: Each of these are comma-separated lists and can be used to generate multiple parameter combinations. Note that tests will be skipped if no model is specified. +### Suites (fast selection + sane defaults) + +- `--suite`: comma-separated list of named suites that both narrow which tests are collected and prefill common model options (unless you pass them explicitly). +- Available suites: + - `responses`: collects tests under `tests/integration/responses`; this is a separate suite because it needs a strong tool-calling model. + - `vision`: collects only `tests/integration/inference/test_vision_inference.py`; defaults `--vision-model=ollama/llama3.2-vision:11b`, `--embedding-model=sentence-transformers/all-MiniLM-L6-v2`. +- Explicit flags always win. For example, `--suite=responses --text-model=` overrides the suite’s text model. + +Examples: + +```bash +# Fast responses run with defaults +pytest -s -v tests/integration --stack-config=server:starter --suite=responses + +# Fast single-file vision run with defaults +pytest -s -v tests/integration --stack-config=server:starter --suite=vision + +# Combine suites and override a default +pytest -s -v tests/integration --stack-config=server:starter --suite=responses,vision --embedding-model=text-embedding-3-small +``` + ## Examples ### Testing against a Server diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index fd9a54d04..96260fdb7 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -6,15 +6,17 @@ import inspect import itertools import os -import platform import textwrap import time +from pathlib import Path import pytest from dotenv import load_dotenv from llama_stack.log import get_logger +from .suites import SUITE_DEFINITIONS + logger = get_logger(__name__, category="tests") @@ -61,9 +63,22 @@ def pytest_configure(config): key, value = env_var.split("=", 1) os.environ[key] = value - if platform.system() == "Darwin": # Darwin is the system name for macOS - os.environ["DISABLE_CODE_SANDBOX"] = "1" - logger.info("Setting DISABLE_CODE_SANDBOX=1 for macOS") + suites_raw = config.getoption("--suite") + suites: list[str] = [] + if suites_raw: + suites = [p.strip() for p in str(suites_raw).split(",") if p.strip()] + unknown = [p for p in suites if p not in SUITE_DEFINITIONS] + if unknown: + raise pytest.UsageError( + f"Unknown suite(s): {', '.join(unknown)}. Available: {', '.join(sorted(SUITE_DEFINITIONS.keys()))}" + ) + for suite in suites: + suite_def = SUITE_DEFINITIONS.get(suite, {}) + defaults: dict = suite_def.get("defaults", {}) + for dest, value in defaults.items(): + current = getattr(config.option, dest, None) + if not current: + setattr(config.option, dest, value) def pytest_addoption(parser): @@ -105,16 +120,21 @@ def pytest_addoption(parser): default=384, help="Output dimensionality of the embedding model to use for testing. Default: 384", ) - parser.addoption( - "--record-responses", - action="store_true", - help="Record new API responses instead of using cached ones.", - ) parser.addoption( "--report", help="Path where the test report should be written, e.g. --report=/path/to/report.md", ) + available_suites = ", ".join(sorted(SUITE_DEFINITIONS.keys())) + suite_help = ( + "Comma-separated integration test suites to narrow collection and prefill defaults. " + "Available: " + f"{available_suites}. " + "Explicit CLI flags (e.g., --text-model) override suite defaults. " + "Examples: --suite=responses or --suite=responses,vision." + ) + parser.addoption("--suite", help=suite_help) + MODEL_SHORT_IDS = { "meta-llama/Llama-3.2-3B-Instruct": "3B", @@ -197,3 +217,40 @@ def pytest_generate_tests(metafunc): pytest_plugins = ["tests.integration.fixtures.common"] + + +def pytest_ignore_collect(path: str, config: pytest.Config) -> bool: + """Skip collecting paths outside the selected suite roots for speed.""" + suites_raw = config.getoption("--suite") + if not suites_raw: + return False + + names = [p.strip() for p in str(suites_raw).split(",") if p.strip()] + roots: list[str] = [] + for name in names: + suite_def = SUITE_DEFINITIONS.get(name) + if suite_def: + roots.extend(suite_def.get("roots", [])) + if not roots: + return False + + p = Path(str(path)).resolve() + + # Only constrain within tests/integration to avoid ignoring unrelated tests + integration_root = (Path(str(config.rootpath)) / "tests" / "integration").resolve() + if not p.is_relative_to(integration_root): + return False + + for r in roots: + rp = (Path(str(config.rootpath)) / r).resolve() + if rp.is_file(): + # Allow the exact file and any ancestor directories so pytest can walk into it. + if p == rp: + return False + if p.is_dir() and rp.is_relative_to(p): + return False + else: + # Allow anything inside an allowed directory + if p.is_relative_to(rp): + return False + return True diff --git a/tests/integration/non_ci/responses/__init__.py b/tests/integration/responses/__init__.py similarity index 100% rename from tests/integration/non_ci/responses/__init__.py rename to tests/integration/responses/__init__.py diff --git a/tests/integration/non_ci/responses/fixtures/__init__.py b/tests/integration/responses/fixtures/__init__.py similarity index 100% rename from tests/integration/non_ci/responses/fixtures/__init__.py rename to tests/integration/responses/fixtures/__init__.py diff --git a/tests/integration/non_ci/responses/fixtures/fixtures.py b/tests/integration/responses/fixtures/fixtures.py similarity index 100% rename from tests/integration/non_ci/responses/fixtures/fixtures.py rename to tests/integration/responses/fixtures/fixtures.py diff --git a/tests/integration/non_ci/responses/fixtures/images/vision_test_1.jpg b/tests/integration/responses/fixtures/images/vision_test_1.jpg similarity index 100% rename from tests/integration/non_ci/responses/fixtures/images/vision_test_1.jpg rename to tests/integration/responses/fixtures/images/vision_test_1.jpg diff --git a/tests/integration/non_ci/responses/fixtures/images/vision_test_2.jpg b/tests/integration/responses/fixtures/images/vision_test_2.jpg similarity index 100% rename from tests/integration/non_ci/responses/fixtures/images/vision_test_2.jpg rename to tests/integration/responses/fixtures/images/vision_test_2.jpg diff --git a/tests/integration/non_ci/responses/fixtures/images/vision_test_3.jpg b/tests/integration/responses/fixtures/images/vision_test_3.jpg similarity index 100% rename from tests/integration/non_ci/responses/fixtures/images/vision_test_3.jpg rename to tests/integration/responses/fixtures/images/vision_test_3.jpg diff --git a/tests/integration/non_ci/responses/fixtures/pdfs/llama_stack_and_models.pdf b/tests/integration/responses/fixtures/pdfs/llama_stack_and_models.pdf similarity index 100% rename from tests/integration/non_ci/responses/fixtures/pdfs/llama_stack_and_models.pdf rename to tests/integration/responses/fixtures/pdfs/llama_stack_and_models.pdf diff --git a/tests/integration/non_ci/responses/fixtures/test_cases.py b/tests/integration/responses/fixtures/test_cases.py similarity index 100% rename from tests/integration/non_ci/responses/fixtures/test_cases.py rename to tests/integration/responses/fixtures/test_cases.py diff --git a/tests/integration/non_ci/responses/helpers.py b/tests/integration/responses/helpers.py similarity index 100% rename from tests/integration/non_ci/responses/helpers.py rename to tests/integration/responses/helpers.py diff --git a/tests/integration/non_ci/responses/streaming_assertions.py b/tests/integration/responses/streaming_assertions.py similarity index 100% rename from tests/integration/non_ci/responses/streaming_assertions.py rename to tests/integration/responses/streaming_assertions.py diff --git a/tests/integration/non_ci/responses/test_basic_responses.py b/tests/integration/responses/test_basic_responses.py similarity index 100% rename from tests/integration/non_ci/responses/test_basic_responses.py rename to tests/integration/responses/test_basic_responses.py diff --git a/tests/integration/non_ci/responses/test_file_search.py b/tests/integration/responses/test_file_search.py similarity index 100% rename from tests/integration/non_ci/responses/test_file_search.py rename to tests/integration/responses/test_file_search.py diff --git a/tests/integration/non_ci/responses/test_tool_responses.py b/tests/integration/responses/test_tool_responses.py similarity index 100% rename from tests/integration/non_ci/responses/test_tool_responses.py rename to tests/integration/responses/test_tool_responses.py diff --git a/tests/integration/suites.py b/tests/integration/suites.py new file mode 100644 index 000000000..602855055 --- /dev/null +++ b/tests/integration/suites.py @@ -0,0 +1,53 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Central definition of integration test suites. You can use these suites by passing --suite=name to pytest. +# For example: +# +# ```bash +# pytest tests/integration/ --suite=vision +# ``` +# +# Each suite can: +# - restrict collection to specific roots (dirs or files) +# - provide default CLI option values (e.g. text_model, embedding_model, etc.) + +from pathlib import Path + +this_dir = Path(__file__).parent +default_roots = [ + str(p) + for p in this_dir.glob("*") + if p.is_dir() + and p.name not in ("__pycache__", "fixtures", "test_cases", "recordings", "responses", "post_training") +] + +SUITE_DEFINITIONS: dict[str, dict] = { + "base": { + "description": "Base suite that includes most tests but runs them with a text Ollama model", + "roots": default_roots, + "defaults": { + "text_model": "ollama/llama3.2:3b-instruct-fp16", + "embedding_model": "sentence-transformers/all-MiniLM-L6-v2", + }, + }, + "responses": { + "description": "Suite that includes only the OpenAI Responses tests; needs a strong tool-calling model", + "roots": ["tests/integration/responses"], + "defaults": { + "text_model": "openai/gpt-4o", + "embedding_model": "sentence-transformers/all-MiniLM-L6-v2", + }, + }, + "vision": { + "description": "Suite that includes only the vision tests", + "roots": ["tests/integration/inference/test_vision_inference.py"], + "defaults": { + "vision_model": "ollama/llama3.2-vision:11b", + "embedding_model": "sentence-transformers/all-MiniLM-L6-v2", + }, + }, +} From 7cd1c2c238969cb41ddb3d87b7dcfe5331350be1 Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Sat, 6 Sep 2025 07:26:34 -0600 Subject: [PATCH 14/17] feat: Updating Rag Tool to use Files API and Vector Stores API (#3344) --- docs/source/getting_started/demo_script.py | 5 +- .../inline/tool_runtime/rag/__init__.py | 2 +- .../inline/tool_runtime/rag/memory.py | 74 ++++++++++++++----- .../providers/registry/tool_runtime.py | 2 +- .../integration/tool_runtime/test_rag_tool.py | 41 ++++++---- tests/unit/rag/test_rag_query.py | 8 +- 6 files changed, 93 insertions(+), 39 deletions(-) diff --git a/docs/source/getting_started/demo_script.py b/docs/source/getting_started/demo_script.py index 777fc78c2..2ea67739f 100644 --- a/docs/source/getting_started/demo_script.py +++ b/docs/source/getting_started/demo_script.py @@ -18,12 +18,13 @@ embedding_model_id = ( ).identifier embedding_dimension = em.metadata["embedding_dimension"] -_ = client.vector_dbs.register( +vector_db = client.vector_dbs.register( vector_db_id=vector_db_id, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, provider_id="faiss", ) +vector_db_id = vector_db.identifier source = "https://www.paulgraham.com/greatwork.html" print("rag_tool> Ingesting document:", source) document = RAGDocument( @@ -35,7 +36,7 @@ document = RAGDocument( client.tool_runtime.rag_tool.insert( documents=[document], vector_db_id=vector_db_id, - chunk_size_in_tokens=50, + chunk_size_in_tokens=100, ) agent = Agent( client, diff --git a/llama_stack/providers/inline/tool_runtime/rag/__init__.py b/llama_stack/providers/inline/tool_runtime/rag/__init__.py index f9a6e5c55..f9a7e7b89 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/__init__.py +++ b/llama_stack/providers/inline/tool_runtime/rag/__init__.py @@ -14,6 +14,6 @@ from .config import RagToolRuntimeConfig async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]): from .memory import MemoryToolRuntimeImpl - impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference]) + impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files]) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index a1543457b..cb526e8ee 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -5,10 +5,15 @@ # the root directory of this source tree. import asyncio +import base64 +import io +import mimetypes import secrets import string from typing import Any +import httpx +from fastapi import UploadFile from pydantic import TypeAdapter from llama_stack.apis.common.content_types import ( @@ -17,6 +22,7 @@ from llama_stack.apis.common.content_types import ( InterleavedContentItem, TextContentItem, ) +from llama_stack.apis.files import Files, OpenAIFilePurpose from llama_stack.apis.inference import Inference from llama_stack.apis.tools import ( ListToolDefsResponse, @@ -30,13 +36,18 @@ from llama_stack.apis.tools import ( ToolParameter, ToolRuntime, ) -from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO +from llama_stack.apis.vector_io import ( + QueryChunksResponse, + VectorIO, + VectorStoreChunkingStrategyStatic, + VectorStoreChunkingStrategyStaticConfig, +) from llama_stack.log import get_logger from llama_stack.providers.datatypes import ToolGroupsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.memory.vector_store import ( content_from_doc, - make_overlapped_chunks, + parse_data_url, ) from .config import RagToolRuntimeConfig @@ -55,10 +66,12 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti config: RagToolRuntimeConfig, vector_io_api: VectorIO, inference_api: Inference, + files_api: Files, ): self.config = config self.vector_io_api = vector_io_api self.inference_api = inference_api + self.files_api = files_api async def initialize(self): pass @@ -78,27 +91,50 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti vector_db_id: str, chunk_size_in_tokens: int = 512, ) -> None: - chunks = [] + if not documents: + return + for doc in documents: - content = await content_from_doc(doc) - # TODO: we should add enrichment here as URLs won't be added to the metadata by default - chunks.extend( - make_overlapped_chunks( - doc.document_id, - content, - chunk_size_in_tokens, - chunk_size_in_tokens // 4, - doc.metadata, + if isinstance(doc.content, URL): + if doc.content.uri.startswith("data:"): + parts = parse_data_url(doc.content.uri) + file_data = base64.b64decode(parts["data"]) if parts["is_base64"] else parts["data"].encode() + mime_type = parts["mimetype"] + else: + async with httpx.AsyncClient() as client: + response = await client.get(doc.content.uri) + file_data = response.content + mime_type = doc.mime_type or response.headers.get("content-type", "application/octet-stream") + else: + content_str = await content_from_doc(doc) + file_data = content_str.encode("utf-8") + mime_type = doc.mime_type or "text/plain" + + file_extension = mimetypes.guess_extension(mime_type) or ".txt" + filename = doc.metadata.get("filename", f"{doc.document_id}{file_extension}") + + file_obj = io.BytesIO(file_data) + file_obj.name = filename + + upload_file = UploadFile(file=file_obj, filename=filename) + + created_file = await self.files_api.openai_upload_file( + file=upload_file, purpose=OpenAIFilePurpose.ASSISTANTS + ) + + chunking_strategy = VectorStoreChunkingStrategyStatic( + static=VectorStoreChunkingStrategyStaticConfig( + max_chunk_size_tokens=chunk_size_in_tokens, + chunk_overlap_tokens=chunk_size_in_tokens // 4, ) ) - if not chunks: - return - - await self.vector_io_api.insert_chunks( - chunks=chunks, - vector_db_id=vector_db_id, - ) + await self.vector_io_api.openai_attach_file_to_vector_store( + vector_store_id=vector_db_id, + file_id=created_file.id, + attributes=doc.metadata, + chunking_strategy=chunking_strategy, + ) async def query( self, diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 661851443..5a58fa7af 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -32,7 +32,7 @@ def available_providers() -> list[ProviderSpec]: ], module="llama_stack.providers.inline.tool_runtime.rag", config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig", - api_dependencies=[Api.vector_io, Api.inference], + api_dependencies=[Api.vector_io, Api.inference, Api.files], description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.", ), remote_provider_spec( diff --git a/tests/integration/tool_runtime/test_rag_tool.py b/tests/integration/tool_runtime/test_rag_tool.py index 2affe2a2d..b208500d8 100644 --- a/tests/integration/tool_runtime/test_rag_tool.py +++ b/tests/integration/tool_runtime/test_rag_tool.py @@ -17,10 +17,14 @@ def client_with_empty_registry(client_with_models): client_with_models.vector_dbs.unregister(vector_db_id=vector_db_id) clear_registry() + + try: + client_with_models.toolgroups.register(toolgroup_id="builtin::rag", provider_id="rag-runtime") + except Exception: + pass + yield client_with_models - # you must clean after the last test if you were running tests against - # a stateful server instance clear_registry() @@ -66,12 +70,13 @@ def assert_valid_text_response(response): def test_vector_db_insert_inline_and_query( client_with_empty_registry, sample_documents, embedding_model_id, embedding_dimension ): - vector_db_id = "test_vector_db" - client_with_empty_registry.vector_dbs.register( - vector_db_id=vector_db_id, + vector_db_name = "test_vector_db" + vector_db = client_with_empty_registry.vector_dbs.register( + vector_db_id=vector_db_name, embedding_model=embedding_model_id, embedding_dimension=embedding_dimension, ) + vector_db_id = vector_db.identifier client_with_empty_registry.tool_runtime.rag_tool.insert( documents=sample_documents, @@ -134,7 +139,11 @@ def test_vector_db_insert_from_url_and_query( # list to check memory bank is successfully registered available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] - assert vector_db_id in available_vector_dbs + # VectorDB is being migrated to VectorStore, so the ID will be different + # Just check that at least one vector DB was registered + assert len(available_vector_dbs) > 0 + # Use the actual registered vector_db_id for subsequent operations + actual_vector_db_id = available_vector_dbs[0] urls = [ "memory_optimizations.rst", @@ -153,13 +162,13 @@ def test_vector_db_insert_from_url_and_query( client_with_empty_registry.tool_runtime.rag_tool.insert( documents=documents, - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, chunk_size_in_tokens=512, ) # Query for the name of method response1 = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query="What's the name of the fine-tunning method used?", ) assert_valid_chunk_response(response1) @@ -167,7 +176,7 @@ def test_vector_db_insert_from_url_and_query( # Query for the name of model response2 = client_with_empty_registry.vector_io.query( - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, query="Which Llama model is mentioned?", ) assert_valid_chunk_response(response2) @@ -187,7 +196,11 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i ) available_vector_dbs = [vector_db.identifier for vector_db in client_with_empty_registry.vector_dbs.list()] - assert vector_db_id in available_vector_dbs + # VectorDB is being migrated to VectorStore, so the ID will be different + # Just check that at least one vector DB was registered + assert len(available_vector_dbs) > 0 + # Use the actual registered vector_db_id for subsequent operations + actual_vector_db_id = available_vector_dbs[0] urls = [ "memory_optimizations.rst", @@ -206,19 +219,19 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i client_with_empty_registry.tool_runtime.rag_tool.insert( documents=documents, - vector_db_id=vector_db_id, + vector_db_id=actual_vector_db_id, chunk_size_in_tokens=512, ) response_with_metadata = client_with_empty_registry.tool_runtime.rag_tool.query( - vector_db_ids=[vector_db_id], + vector_db_ids=[actual_vector_db_id], content="What is the name of the method used for fine-tuning?", ) assert_valid_text_response(response_with_metadata) assert any("metadata:" in chunk.text.lower() for chunk in response_with_metadata.content) response_without_metadata = client_with_empty_registry.tool_runtime.rag_tool.query( - vector_db_ids=[vector_db_id], + vector_db_ids=[actual_vector_db_id], content="What is the name of the method used for fine-tuning?", query_config={ "include_metadata_in_content": True, @@ -230,7 +243,7 @@ def test_rag_tool_insert_and_query(client_with_empty_registry, embedding_model_i with pytest.raises((ValueError, BadRequestError)): client_with_empty_registry.tool_runtime.rag_tool.query( - vector_db_ids=[vector_db_id], + vector_db_ids=[actual_vector_db_id], content="What is the name of the method used for fine-tuning?", query_config={ "chunk_template": "This should raise a ValueError because it is missing the proper template variables", diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py index 05ccecb99..d18d90716 100644 --- a/tests/unit/rag/test_rag_query.py +++ b/tests/unit/rag/test_rag_query.py @@ -19,12 +19,16 @@ from llama_stack.providers.inline.tool_runtime.rag.memory import MemoryToolRunti class TestRagQuery: async def test_query_raises_on_empty_vector_db_ids(self): - rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock()) + rag_tool = MemoryToolRuntimeImpl( + config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock() + ) with pytest.raises(ValueError): await rag_tool.query(content=MagicMock(), vector_db_ids=[]) async def test_query_chunk_metadata_handling(self): - rag_tool = MemoryToolRuntimeImpl(config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock()) + rag_tool = MemoryToolRuntimeImpl( + config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock() + ) content = "test query content" vector_db_ids = ["db1"] From d6c3b363904eedd9c5323594603dc4f2d5b581eb Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Sat, 6 Sep 2025 15:22:20 -0400 Subject: [PATCH 15/17] chore: update the gemini inference impl to use openai-python for openai-compat functions (#3351) # What does this PR do? update the Gemini inference provider to use openai-python for the openai-compat endpoints partially addresses #3349, does not address /inference/completion or /inference/chat-completion ## Test Plan ci --- llama_stack/providers/registry/inference.py | 2 +- llama_stack/providers/remote/inference/gemini/gemini.py | 8 +++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 50956f58c..1bb3c3147 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -207,7 +207,7 @@ def available_providers() -> list[ProviderSpec]: api=Api.inference, adapter=AdapterSpec( adapter_type="gemini", - pip_packages=["litellm"], + pip_packages=["litellm", "openai"], module="llama_stack.providers.remote.inference.gemini", config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig", provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator", diff --git a/llama_stack/providers/remote/inference/gemini/gemini.py b/llama_stack/providers/remote/inference/gemini/gemini.py index b6048eff7..569227fdd 100644 --- a/llama_stack/providers/remote/inference/gemini/gemini.py +++ b/llama_stack/providers/remote/inference/gemini/gemini.py @@ -5,12 +5,13 @@ # the root directory of this source tree. from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import GeminiConfig from .models import MODEL_ENTRIES -class GeminiInferenceAdapter(LiteLLMOpenAIMixin): +class GeminiInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): def __init__(self, config: GeminiConfig) -> None: LiteLLMOpenAIMixin.__init__( self, @@ -21,6 +22,11 @@ class GeminiInferenceAdapter(LiteLLMOpenAIMixin): ) self.config = config + get_api_key = LiteLLMOpenAIMixin.get_api_key + + def get_base_url(self): + return "https://generativelanguage.googleapis.com/v1beta/openai/" + async def initialize(self) -> None: await super().initialize() From 4c28544c04ab96c5c4d188dda3a53dc2eab0415d Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Sat, 6 Sep 2025 15:22:44 -0400 Subject: [PATCH 16/17] chore(gemini, tests): add skips for n and completions, gemini api does not support them (#3350) # What does this PR do? the gemini api endpoints do not support the n param or completions ## Test Plan ci --- tests/integration/inference/test_openai_completion.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index bb447b3c1..099032578 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -37,6 +37,7 @@ def skip_if_model_doesnt_support_openai_completion(client_with_models, model_id) "remote::sambanova", "remote::tgi", "remote::vertexai", + "remote::gemini", # https://generativelanguage.googleapis.com/v1beta/openai/completions -> 404 ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support OpenAI completions.") @@ -63,6 +64,9 @@ def skip_if_doesnt_support_n(client_with_models, model_id): if provider.provider_type in ( "remote::sambanova", "remote::ollama", + # Error code: 400 - [{'error': {'code': 400, 'message': 'Only one candidate can be specified in the + # current model', 'status': 'INVALID_ARGUMENT'}}] + "remote::gemini", ): pytest.skip(f"Model {model_id} hosted by {provider.provider_type} doesn't support n param.") From bf02cd846fdc39db80291746e06ca547e5afdbdb Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Sat, 6 Sep 2025 15:25:13 -0400 Subject: [PATCH 17/17] chore: update the sambanova inference impl to use openai-python for openai-compat functions (#3345) # What does this PR do? update SambaNova inference provider to use OpenAIMixin for openai-compat endpoints ## Test Plan ``` $ SAMBANOVA_API_KEY=... uv run llama stack build --image-type venv --providers inference=remote::sambanova --run ... $ LLAMA_STACK_CONFIG=http://localhost:8321 uv run --group test pytest -v -ra --text-model sambanova/Meta-Llama-3.3-70B-Instruct tests/integration/inference -k 'not store' ... FAILED tests/integration/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[txt=sambanova/Meta-Llama-3.3-70B-Instruct-inference:chat_completion:tool_calling_tools_absent-True] - AttributeError: 'NoneType' object has no attribute 'delta' FAILED tests/integration/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[txt=sambanova/Meta-Llama-3.3-70B-Instruct-inference:chat_completion:tool_calling_tools_absent-False] - llama_stack_client.InternalServerError: Error code: 500 - {'detail': 'Internal server error: An une... =========== 2 failed, 16 passed, 68 skipped, 8 deselected, 3 xfailed, 13 warnings in 15.85s ============ ``` the two failures also exist before this change. they are part of the deprecated inference.chat_completion tests that flow through litellm. they can be resolved later. --- llama_stack/providers/registry/inference.py | 2 +- .../remote/inference/sambanova/sambanova.py | 26 ++++++++++++++++++- 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 1bb3c3147..7a95fd089 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -270,7 +270,7 @@ Available Models: api=Api.inference, adapter=AdapterSpec( adapter_type="sambanova", - pip_packages=["litellm"], + pip_packages=["litellm", "openai"], module="llama_stack.providers.remote.inference.sambanova", config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig", provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator", diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 96469acac..ee3b0f648 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -4,13 +4,26 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin +from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin from .config import SambaNovaImplConfig from .models import MODEL_ENTRIES -class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin): +class SambaNovaInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin): + """ + SambaNova Inference Adapter for Llama Stack. + + Note: The inheritance order is important here. OpenAIMixin must come before + LiteLLMOpenAIMixin to ensure that OpenAIMixin.check_model_availability() + is used instead of LiteLLMOpenAIMixin.check_model_availability(). + + - OpenAIMixin.check_model_availability() queries the /v1/models to check if a model exists + - LiteLLMOpenAIMixin.check_model_availability() checks the static registry within LiteLLM + """ + def __init__(self, config: SambaNovaImplConfig): self.config = config self.environment_available_models = [] @@ -24,3 +37,14 @@ class SambaNovaInferenceAdapter(LiteLLMOpenAIMixin): download_images=True, # SambaNova requires base64 image encoding json_schema_strict=False, # SambaNova doesn't support strict=True yet ) + + # Delegate the client data handling get_api_key method to LiteLLMOpenAIMixin + get_api_key = LiteLLMOpenAIMixin.get_api_key + + def get_base_url(self) -> str: + """ + Get the base URL for OpenAI mixin. + + :return: The SambaNova base URL + """ + return self.config.url