From b2c7543af72d80103fac4e1e8ff41302e59bfe57 Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Mon, 21 Jul 2025 14:03:40 +0200 Subject: [PATCH 01/10] fix(vectordb): VectorDBInput has no provider_id (#2830) # What does this PR do? This PR add `provider_id` field to `VectorDBInput` class. fixes https://github.com/meta-llama/llama-stack/issues/2819 Signed-off-by: Mustafa Elbehery --- llama_stack/apis/vector_dbs/vector_dbs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llama_stack/apis/vector_dbs/vector_dbs.py b/llama_stack/apis/vector_dbs/vector_dbs.py index 0d160737a..325e21bab 100644 --- a/llama_stack/apis/vector_dbs/vector_dbs.py +++ b/llama_stack/apis/vector_dbs/vector_dbs.py @@ -34,6 +34,7 @@ class VectorDBInput(BaseModel): vector_db_id: str embedding_model: str embedding_dimension: int + provider_id: str | None = None provider_vector_db_id: str | None = None From 89c49eb003cfc7b70babf85bab81bc3ebc81b63b Mon Sep 17 00:00:00 2001 From: Ondrej Metelka Date: Mon, 21 Jul 2025 15:43:32 +0200 Subject: [PATCH 02/10] feat: Allow application/yaml as mime_type (#2575) # What does this PR do? Allow application/yaml as mime_type for documents. ## Test Plan Added unit tests. --- .../agents/meta_reference/agent_instance.py | 11 +- .../agent/test_get_raw_document_text.py | 187 ++++++++++++++++++ 2 files changed, 197 insertions(+), 1 deletion(-) create mode 100644 tests/unit/providers/agent/test_get_raw_document_text.py diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 4d2b9f8bf..3c34c71fb 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -10,6 +10,7 @@ import re import secrets import string import uuid +import warnings from collections.abc import AsyncGenerator from datetime import UTC, datetime @@ -911,8 +912,16 @@ async def load_data_from_url(url: str) -> str: async def get_raw_document_text(document: Document) -> str: - if not document.mime_type.startswith("text/"): + # Handle deprecated text/yaml mime type with warning + if document.mime_type == "text/yaml": + warnings.warn( + "The 'text/yaml' MIME type is deprecated. Please use 'application/yaml' instead.", + DeprecationWarning, + stacklevel=2, + ) + elif not (document.mime_type.startswith("text/") or document.mime_type == "application/yaml"): raise ValueError(f"Unexpected document mime type: {document.mime_type}") + if isinstance(document.content, URL): return await load_data_from_url(document.content.uri) elif isinstance(document.content, str): diff --git a/tests/unit/providers/agent/test_get_raw_document_text.py b/tests/unit/providers/agent/test_get_raw_document_text.py new file mode 100644 index 000000000..ddc886293 --- /dev/null +++ b/tests/unit/providers/agent/test_get_raw_document_text.py @@ -0,0 +1,187 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import warnings +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from llama_stack.apis.agents import Document +from llama_stack.apis.common.content_types import URL, TextContentItem +from llama_stack.providers.inline.agents.meta_reference.agent_instance import get_raw_document_text + + +@pytest.mark.asyncio +async def test_get_raw_document_text_supports_text_mime_types(): + """Test that the function accepts text/* mime types.""" + document = Document(content="Sample text content", mime_type="text/plain") + + result = await get_raw_document_text(document) + assert result == "Sample text content" + + +@pytest.mark.asyncio +async def test_get_raw_document_text_supports_yaml_mime_type(): + """Test that the function accepts application/yaml mime type.""" + yaml_content = """ + name: test + version: 1.0 + items: + - item1 + - item2 + """ + + document = Document(content=yaml_content, mime_type="application/yaml") + + result = await get_raw_document_text(document) + assert result == yaml_content + + +@pytest.mark.asyncio +async def test_get_raw_document_text_supports_deprecated_text_yaml_with_warning(): + """Test that the function accepts text/yaml but emits a deprecation warning.""" + yaml_content = """ + name: test + version: 1.0 + items: + - item1 + - item2 + """ + + document = Document(content=yaml_content, mime_type="text/yaml") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = await get_raw_document_text(document) + + # Check that result is correct + assert result == yaml_content + + # Check that exactly one warning was issued + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "text/yaml" in str(w[0].message) + assert "application/yaml" in str(w[0].message) + assert "deprecated" in str(w[0].message).lower() + + +@pytest.mark.asyncio +async def test_get_raw_document_text_deprecated_text_yaml_with_url(): + """Test that text/yaml works with URL content and emits warning.""" + yaml_content = "name: test\nversion: 1.0" + + with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load: + mock_load.return_value = yaml_content + + document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="text/yaml") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = await get_raw_document_text(document) + + # Check that result is correct + assert result == yaml_content + mock_load.assert_called_once_with("https://example.com/config.yaml") + + # Check that deprecation warning was issued + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "text/yaml" in str(w[0].message) + + +@pytest.mark.asyncio +async def test_get_raw_document_text_deprecated_text_yaml_with_text_content_item(): + """Test that text/yaml works with TextContentItem and emits warning.""" + yaml_content = "key: value\nlist:\n - item1\n - item2" + + document = Document(content=TextContentItem(text=yaml_content), mime_type="text/yaml") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = await get_raw_document_text(document) + + # Check that result is correct + assert result == yaml_content + + # Check that deprecation warning was issued + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "text/yaml" in str(w[0].message) + + +@pytest.mark.asyncio +async def test_get_raw_document_text_rejects_unsupported_mime_types(): + """Test that the function rejects unsupported mime types.""" + document = Document( + content="Some content", + mime_type="application/json", # Not supported + ) + + with pytest.raises(ValueError, match="Unexpected document mime type: application/json"): + await get_raw_document_text(document) + + +@pytest.mark.asyncio +async def test_get_raw_document_text_with_url_content(): + """Test that the function handles URL content correctly.""" + mock_response = AsyncMock() + mock_response.text = "Content from URL" + + with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load: + mock_load.return_value = "Content from URL" + + document = Document(content=URL(uri="https://example.com/test.txt"), mime_type="text/plain") + + result = await get_raw_document_text(document) + assert result == "Content from URL" + mock_load.assert_called_once_with("https://example.com/test.txt") + + +@pytest.mark.asyncio +async def test_get_raw_document_text_with_yaml_url(): + """Test that the function handles YAML URLs correctly.""" + yaml_content = "name: test\nversion: 1.0" + + with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load: + mock_load.return_value = yaml_content + + document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="application/yaml") + + result = await get_raw_document_text(document) + assert result == yaml_content + mock_load.assert_called_once_with("https://example.com/config.yaml") + + +@pytest.mark.asyncio +async def test_get_raw_document_text_with_text_content_item(): + """Test that the function handles TextContentItem correctly.""" + document = Document(content=TextContentItem(text="Text content item"), mime_type="text/plain") + + result = await get_raw_document_text(document) + assert result == "Text content item" + + +@pytest.mark.asyncio +async def test_get_raw_document_text_with_yaml_text_content_item(): + """Test that the function handles YAML TextContentItem correctly.""" + yaml_content = "key: value\nlist:\n - item1\n - item2" + + document = Document(content=TextContentItem(text=yaml_content), mime_type="application/yaml") + + result = await get_raw_document_text(document) + assert result == yaml_content + + +@pytest.mark.asyncio +async def test_get_raw_document_text_rejects_unexpected_content_type(): + """Test that the function rejects unexpected document content types.""" + # Create a mock document that bypasses Pydantic validation + mock_document = MagicMock(spec=Document) + mock_document.mime_type = "text/plain" + mock_document.content = 123 # Unexpected content type (not str, URL, or TextContentItem) + + with pytest.raises(ValueError, match="Unexpected document content type: "): + await get_raw_document_text(mock_document) From 9e6860b9cf27afaf48115df48c1670d6c33c40ba Mon Sep 17 00:00:00 2001 From: IAN MILLER <75687988+r3v5@users.noreply.github.com> Date: Mon, 21 Jul 2025 17:14:34 +0100 Subject: [PATCH 03/10] fix: remove @pytest.mark.asyncio from test_get_raw_document_text.py (#2840) # What does this PR do? The pre-commit workflow was failing in the main branch and removing `@pytest.mark.asyncio `from `test_get_raw_document_text.py` fixed that. ## Test Plan --- .../providers/agent/test_get_raw_document_text.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/unit/providers/agent/test_get_raw_document_text.py b/tests/unit/providers/agent/test_get_raw_document_text.py index ddc886293..eb481c0d8 100644 --- a/tests/unit/providers/agent/test_get_raw_document_text.py +++ b/tests/unit/providers/agent/test_get_raw_document_text.py @@ -14,7 +14,6 @@ from llama_stack.apis.common.content_types import URL, TextContentItem from llama_stack.providers.inline.agents.meta_reference.agent_instance import get_raw_document_text -@pytest.mark.asyncio async def test_get_raw_document_text_supports_text_mime_types(): """Test that the function accepts text/* mime types.""" document = Document(content="Sample text content", mime_type="text/plain") @@ -23,7 +22,6 @@ async def test_get_raw_document_text_supports_text_mime_types(): assert result == "Sample text content" -@pytest.mark.asyncio async def test_get_raw_document_text_supports_yaml_mime_type(): """Test that the function accepts application/yaml mime type.""" yaml_content = """ @@ -40,7 +38,6 @@ async def test_get_raw_document_text_supports_yaml_mime_type(): assert result == yaml_content -@pytest.mark.asyncio async def test_get_raw_document_text_supports_deprecated_text_yaml_with_warning(): """Test that the function accepts text/yaml but emits a deprecation warning.""" yaml_content = """ @@ -68,7 +65,6 @@ async def test_get_raw_document_text_supports_deprecated_text_yaml_with_warning( assert "deprecated" in str(w[0].message).lower() -@pytest.mark.asyncio async def test_get_raw_document_text_deprecated_text_yaml_with_url(): """Test that text/yaml works with URL content and emits warning.""" yaml_content = "name: test\nversion: 1.0" @@ -92,7 +88,6 @@ async def test_get_raw_document_text_deprecated_text_yaml_with_url(): assert "text/yaml" in str(w[0].message) -@pytest.mark.asyncio async def test_get_raw_document_text_deprecated_text_yaml_with_text_content_item(): """Test that text/yaml works with TextContentItem and emits warning.""" yaml_content = "key: value\nlist:\n - item1\n - item2" @@ -112,7 +107,6 @@ async def test_get_raw_document_text_deprecated_text_yaml_with_text_content_item assert "text/yaml" in str(w[0].message) -@pytest.mark.asyncio async def test_get_raw_document_text_rejects_unsupported_mime_types(): """Test that the function rejects unsupported mime types.""" document = Document( @@ -124,7 +118,6 @@ async def test_get_raw_document_text_rejects_unsupported_mime_types(): await get_raw_document_text(document) -@pytest.mark.asyncio async def test_get_raw_document_text_with_url_content(): """Test that the function handles URL content correctly.""" mock_response = AsyncMock() @@ -140,7 +133,6 @@ async def test_get_raw_document_text_with_url_content(): mock_load.assert_called_once_with("https://example.com/test.txt") -@pytest.mark.asyncio async def test_get_raw_document_text_with_yaml_url(): """Test that the function handles YAML URLs correctly.""" yaml_content = "name: test\nversion: 1.0" @@ -155,7 +147,6 @@ async def test_get_raw_document_text_with_yaml_url(): mock_load.assert_called_once_with("https://example.com/config.yaml") -@pytest.mark.asyncio async def test_get_raw_document_text_with_text_content_item(): """Test that the function handles TextContentItem correctly.""" document = Document(content=TextContentItem(text="Text content item"), mime_type="text/plain") @@ -164,7 +155,6 @@ async def test_get_raw_document_text_with_text_content_item(): assert result == "Text content item" -@pytest.mark.asyncio async def test_get_raw_document_text_with_yaml_text_content_item(): """Test that the function handles YAML TextContentItem correctly.""" yaml_content = "key: value\nlist:\n - item1\n - item2" @@ -175,7 +165,6 @@ async def test_get_raw_document_text_with_yaml_text_content_item(): assert result == yaml_content -@pytest.mark.asyncio async def test_get_raw_document_text_rejects_unexpected_content_type(): """Test that the function rejects unexpected document content types.""" # Create a mock document that bypasses Pydantic validation From d0208df286987dc78cd2f64d4326ae247e7b39d6 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Mon, 21 Jul 2025 10:01:40 -0700 Subject: [PATCH 04/10] test: skip flaky telemetry tests (#2814) # What does this PR do? example error: https://github.com/meta-llama/llama-stack/actions/runs/16368394907/job/46250869773 ## Test Plan --- tests/integration/telemetry/test_telemetry.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/integration/telemetry/test_telemetry.py b/tests/integration/telemetry/test_telemetry.py index 9df03da70..675dc780a 100644 --- a/tests/integration/telemetry/test_telemetry.py +++ b/tests/integration/telemetry/test_telemetry.py @@ -50,6 +50,7 @@ def setup_telemetry_data(llama_stack_client, text_model_id): yield +@pytest.mark.skip(reason="Skipping telemetry tests for now") def test_query_traces_basic(llama_stack_client): """Test basic trace querying functionality with proper data validation.""" all_traces = llama_stack_client.telemetry.query_traces(limit=5) @@ -105,6 +106,7 @@ def test_query_traces_basic(llama_stack_client): assert hasattr(trace, "root_span_id") and trace.root_span_id, "Each trace should have non-empty root_span_id" +@pytest.mark.skip(reason="Skipping telemetry tests for now") def test_query_spans_basic(llama_stack_client): """Test basic span querying functionality with proper validation.""" spans = llama_stack_client.telemetry.query_spans(attribute_filters=[], attributes_to_return=[]) @@ -153,6 +155,7 @@ def test_query_spans_basic(llama_stack_client): assert hasattr(span, attr) and getattr(span, attr), f"All spans should have non-empty {attr}" +@pytest.mark.skip(reason="Skipping telemetry tests for now") def test_telemetry_pagination(llama_stack_client): """Test pagination in telemetry queries.""" # Get total count of traces From 019ddda13845e816aaf728924fbe4232565e07cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Mon, 21 Jul 2025 20:35:15 +0200 Subject: [PATCH 05/10] fix: graceful SIGINT on server (#2831) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? After https://github.com/meta-llama/llama-stack/pull/2818, SIGINT will print a stack trace. This is because uvicorn re-raises SIGINT and it gets converted by Python internal signal handler (default handles SIGINT) to KeyboardInterrupt exception. We know simply catch the exception to get a clean exit, this is not changing the behavior on SIGINT. ## Test Plan Run the server, hit Ctrl+C or `kill -2 ` and expect a clean exit with no stack trace. Signed-off-by: Sébastien Han --- llama_stack/distribution/server/server.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index e7e9e5e88..935688946 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -597,7 +597,23 @@ def main(args: argparse.Namespace | None = None): uvicorn_config.update(ssl_config) # Run uvicorn in the existing event loop to preserve background tasks - loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve()) + # We need to catch KeyboardInterrupt because uvicorn's signal handling + # re-raises SIGINT signals using signal.raise_signal(), which Python + # converts to KeyboardInterrupt. Without this catch, we'd get a confusing + # stack trace when using Ctrl+C or kill -2 (SIGINT). + # SIGTERM (kill -15) works fine without this because Python doesn't + # have a default handler for it. + # + # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own + # signal handling but this is quite intrusive and not worth the effort. + try: + loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve()) + except (KeyboardInterrupt, SystemExit): + logger.info("Received interrupt signal, shutting down gracefully...") + finally: + if not loop.is_closed(): + logger.debug("Closing event loop") + loop.close() def _log_run_config(run_config: StackRunConfig): From 9a03526672dd34e65940015d477db9ec6f4532ba Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Mon, 21 Jul 2025 15:50:39 -0400 Subject: [PATCH 06/10] fix: uvicorn respect log_config (#2842) --- llama_stack/distribution/server/server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 935688946..ede65e8d6 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -592,6 +592,7 @@ def main(args: argparse.Namespace | None = None): "port": port, "lifespan": "on", "log_level": logger.getEffectiveLevel(), + "log_config": logger_config, } if ssl_config: uvicorn_config.update(ssl_config) From 0d7a90b8bc50f31acc12c244456a59f7bd0559e5 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Mon, 21 Jul 2025 13:19:27 -0700 Subject: [PATCH 07/10] chore: merge --config and --template in server.py (#2716) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? Part of #2696 ## Test Plan Run `llama stack run starter` Error: ``` myenv ❯ llama stack run starters WARNING 2025-07-10 12:12:43,052 llama_stack.cli.stack.run:82 server: Conda detected. Using conda environment myenv for the run. usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--env KEY=VALUE] [--image-type {conda,venv}] [--enable-ui] [config | template] llama stack run: error: Could not resolve config or template 'starters'. Tried the following locations: 1. As file path: /Users/erichuang/projects/llama-stack-git/starters 2. As template: /Users/erichuang/projects/llama-stack-git/llama_stack/templates/starters/run.yaml 3. As built distribution: (/Users/erichuang/.llama/distributions/llamastack-starters/starters-run.yaml, /Users/erichuang/.llama/distributions/starters/starters-run.yaml) Available templates: dell, test-env, vllm-gpu, test-template, cerebras, openai-api-verification, sambanova, passthrough, direct-config, together, openai, fireworks, meta-reference-gpu, __pycache__, dev, ollama, watsonx, remote-vllm, llama_api, groq, dummy, oracle, nvidia, ci-tests, postgres-demo, test-stack, bedrock, starter, hf-serverless, hf-endpoint, tgi, open-benchmark, verification Did you mean one of these templates? - starter - together - postgres-demo ``` --- llama_stack/cli/stack/_build.py | 4 +- llama_stack/cli/stack/run.py | 52 ++------ llama_stack/cli/utils.py | 31 +++++ llama_stack/distribution/server/server.py | 36 +---- llama_stack/distribution/start_stack.sh | 2 +- .../distribution/utils/config_resolution.py | 125 ++++++++++++++++++ llama_stack/distribution/utils/exec.py | 2 +- 7 files changed, 176 insertions(+), 76 deletions(-) create mode 100644 llama_stack/cli/utils.py create mode 100644 llama_stack/distribution/utils/config_resolution.py diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index b573b2edc..3f94b1e2c 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -276,8 +276,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None: config = parse_and_maybe_upgrade_config(config_dict) if config.external_providers_dir and not config.external_providers_dir.exists(): config.external_providers_dir.mkdir(exist_ok=True) - run_args = formulate_run_args(args.image_type, args.image_name, config, args.template) - run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config]) + run_args = formulate_run_args(args.image_type, args.image_name) + run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", str(run_config)]) run_command(run_args) diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index f4a119522..3cb2e213c 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -82,39 +82,6 @@ class StackRun(Subcommand): return ImageType.CONDA.value, args.image_name return args.image_type, args.image_name - def _resolve_config_and_template(self, args: argparse.Namespace) -> tuple[Path | None, str | None]: - """Resolve config file path and template name from args.config""" - from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR - - if not args.config: - return None, None - - config_file = Path(args.config) - has_yaml_suffix = args.config.endswith(".yaml") - template_name = None - - if not config_file.exists() and not has_yaml_suffix: - # check if this is a template - config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml" - if config_file.exists(): - template_name = args.config - - if not config_file.exists() and not has_yaml_suffix: - # check if it's a build config saved to ~/.llama dir - config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml") - - if not config_file.exists(): - self.parser.error( - f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file" - ) - - if not config_file.is_file(): - self.parser.error( - f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}" - ) - - return config_file, template_name - def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: import yaml @@ -125,8 +92,15 @@ class StackRun(Subcommand): self._start_ui_development_server(args.port) image_type, image_name = self._get_image_type_and_name(args) - # Resolve config file and template name first - config_file, template_name = self._resolve_config_and_template(args) + if args.config: + try: + from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template + + config_file = resolve_config_or_template(args.config, Mode.RUN) + except ValueError as e: + self.parser.error(str(e)) + else: + config_file = None # Check if config is required based on image type if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not config_file: @@ -164,18 +138,14 @@ class StackRun(Subcommand): if callable(getattr(args, arg)): continue if arg == "config": - if template_name: - server_args.template = str(template_name) - else: - # Set the config file path - server_args.config = str(config_file) + server_args.config = str(config_file) else: setattr(server_args, arg, getattr(args, arg)) # Run the server server_main(server_args) else: - run_args = formulate_run_args(image_type, image_name, config, template_name) + run_args = formulate_run_args(image_type, image_name) run_args.extend([str(args.port)]) diff --git a/llama_stack/cli/utils.py b/llama_stack/cli/utils.py new file mode 100644 index 000000000..433627cc0 --- /dev/null +++ b/llama_stack/cli/utils.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse + + +def add_config_template_args(parser: argparse.ArgumentParser): + """Add unified config/template arguments with backward compatibility.""" + group = parser.add_mutually_exclusive_group(required=True) + + group.add_argument( + "config", + nargs="?", + help="Configuration file path or template name", + ) + + # Backward compatibility arguments (deprecated) + group.add_argument( + "--config", + dest="config", + help="(DEPRECATED) Use positional argument [config] instead. Configuration file path", + ) + + group.add_argument( + "--template", + dest="config", + help="(DEPRECATED) Use positional argument [config] instead. Template name", + ) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index ede65e8d6..e58c28f2e 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -32,6 +32,7 @@ from openai import BadRequestError from pydantic import BaseModel, ValidationError from llama_stack.apis.common.responses import PaginatedResponse +from llama_stack.cli.utils import add_config_template_args from llama_stack.distribution.access_control.access_control import AccessDeniedError from llama_stack.distribution.datatypes import ( AuthenticationRequiredError, @@ -53,6 +54,7 @@ from llama_stack.distribution.stack import ( validate_env_pair, ) from llama_stack.distribution.utils.config import redact_sensitive_fields +from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template from llama_stack.distribution.utils.context import preserve_contexts_async_generator from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api @@ -377,20 +379,8 @@ class ClientVersionMiddleware: def main(args: argparse.Namespace | None = None): """Start the LlamaStack server.""" parser = argparse.ArgumentParser(description="Start the LlamaStack server.") - parser.add_argument( - "--yaml-config", - dest="config", - help="(Deprecated) Path to YAML configuration file - use --config instead", - ) - parser.add_argument( - "--config", - dest="config", - help="Path to YAML configuration file", - ) - parser.add_argument( - "--template", - help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)", - ) + + add_config_template_args(parser) parser.add_argument( "--port", type=int, @@ -409,20 +399,7 @@ def main(args: argparse.Namespace | None = None): if args is None: args = parser.parse_args() - log_line = "" - if hasattr(args, "config") and args.config: - # if the user provided a config file, use it, even if template was specified - config_file = Path(args.config) - if not config_file.exists(): - raise ValueError(f"Config file {config_file} does not exist") - log_line = f"Using config file: {config_file}" - elif hasattr(args, "template") and args.template: - config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml" - if not config_file.exists(): - raise ValueError(f"Template {args.template} does not exist") - log_line = f"Using template {args.template} config file: {config_file}" - else: - raise ValueError("Either --config or --template must be provided") + config_file = resolve_config_or_template(args.config, Mode.RUN) logger_config = None with open(config_file) as fp: @@ -442,9 +419,6 @@ def main(args: argparse.Namespace | None = None): config = replace_env_vars(config_contents) config = StackRunConfig(**cast_image_name_to_string(config)) - # now that the logger is initialized, print the line about which type of config we are using. - logger.info(log_line) - _log_run_config(run_config=config) app = FastAPI( diff --git a/llama_stack/distribution/start_stack.sh b/llama_stack/distribution/start_stack.sh index 85bfceec4..77a7dc92e 100755 --- a/llama_stack/distribution/start_stack.sh +++ b/llama_stack/distribution/start_stack.sh @@ -117,7 +117,7 @@ if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then set -x if [ -n "$yaml_config" ]; then - yaml_config_arg="--config $yaml_config" + yaml_config_arg="$yaml_config" else yaml_config_arg="" fi diff --git a/llama_stack/distribution/utils/config_resolution.py b/llama_stack/distribution/utils/config_resolution.py new file mode 100644 index 000000000..7e8de1242 --- /dev/null +++ b/llama_stack/distribution/utils/config_resolution.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import StrEnum +from pathlib import Path + +from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="config_resolution") + + +TEMPLATE_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "templates" + + +class Mode(StrEnum): + RUN = "run" + BUILD = "build" + + +def resolve_config_or_template( + config_or_template: str, + mode: Mode = Mode.RUN, +) -> Path: + """ + Resolve a config/template argument to a concrete config file path. + + Args: + config_or_template: User input (file path, template name, or built distribution) + mode: Mode resolving for ("run", "build", "server") + + Returns: + Path to the resolved config file + + Raises: + ValueError: If resolution fails + """ + + # Strategy 1: Try as file path first + config_path = Path(config_or_template) + if config_path.exists() and config_path.is_file(): + logger.info(f"Using file path: {config_path}") + return config_path.resolve() + + # Strategy 2: Try as template name (if no .yaml extension) + if not config_or_template.endswith(".yaml"): + template_config = _get_template_config_path(config_or_template, mode) + if template_config.exists(): + logger.info(f"Using template: {template_config}") + return template_config + + # Strategy 3: Try as built distribution name + distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml" + if distrib_config.exists(): + logger.info(f"Using built distribution: {distrib_config}") + return distrib_config + + distrib_config = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml" + if distrib_config.exists(): + logger.info(f"Using built distribution: {distrib_config}") + return distrib_config + + # Strategy 4: Failed - provide helpful error + raise ValueError(_format_resolution_error(config_or_template, mode)) + + +def _get_template_config_path(template_name: str, mode: Mode) -> Path: + """Get the config file path for a template.""" + return TEMPLATE_DIR / template_name / f"{mode}.yaml" + + +def _format_resolution_error(config_or_template: str, mode: Mode) -> str: + """Format a helpful error message for resolution failures.""" + from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR + + template_path = _get_template_config_path(config_or_template, mode) + distrib_path = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml" + distrib_path2 = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml" + + available_templates = _get_available_templates() + templates_str = ", ".join(available_templates) if available_templates else "none found" + + return f"""Could not resolve config or template '{config_or_template}'. + +Tried the following locations: + 1. As file path: {Path(config_or_template).resolve()} + 2. As template: {template_path} + 3. As built distribution: ({distrib_path}, {distrib_path2}) + +Available templates: {templates_str} + +Did you mean one of these templates? +{_format_template_suggestions(available_templates, config_or_template)} +""" + + +def _get_available_templates() -> list[str]: + """Get list of available template names.""" + if not TEMPLATE_DIR.exists() and not DISTRIBS_BASE_DIR.exists(): + return [] + + return list( + set( + [d.name for d in TEMPLATE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")] + + [d.name for d in DISTRIBS_BASE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")] + ) + ) + + +def _format_template_suggestions(templates: list[str], user_input: str) -> str: + """Format template suggestions for error messages, showing closest matches first.""" + if not templates: + return " (no templates found)" + + import difflib + + # Get up to 3 closest matches with similarity threshold of 0.3 (lower = more permissive) + close_matches = difflib.get_close_matches(user_input, templates, n=3, cutoff=0.3) + display_templates = close_matches if close_matches else templates[:3] + + suggestions = [f" - {t}" for t in display_templates] + return "\n".join(suggestions) diff --git a/llama_stack/distribution/utils/exec.py b/llama_stack/distribution/utils/exec.py index 2db01689f..c646ae821 100644 --- a/llama_stack/distribution/utils/exec.py +++ b/llama_stack/distribution/utils/exec.py @@ -21,7 +21,7 @@ from pathlib import Path from llama_stack.distribution.utils.image_types import LlamaStackImageType -def formulate_run_args(image_type, image_name, config, template_name) -> list: +def formulate_run_args(image_type: str, image_name: str) -> list[str]: env_name = "" if image_type == LlamaStackImageType.CONDA.value: From c8f274347d18acd9ed7e657ecf63d8acb0c5ba03 Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Mon, 21 Jul 2025 16:22:44 -0400 Subject: [PATCH 08/10] chore: Adding Access Control for OpenAI Vector Stores methods (#2772) # What does this PR do? Refactors the vector store routing logic by moving OpenAI-compatible vector store operations from the `VectorIORouter` to the `VectorDBsRoutingTable`. Closes https://github.com/meta-llama/llama-stack/issues/2761 ## Test Plan Added unit tests to cover new routing logic and ACL checks. --------- Signed-off-by: Francisco Javier Arceo --- llama_stack/distribution/routers/vector_io.py | 43 +-- .../distribution/routing_tables/common.py | 15 + .../distribution/routing_tables/vector_dbs.py | 145 +++++++++ .../routers/test_routing_tables.py | 45 +-- .../distribution/routing_tables/__init__.py | 5 + .../routing_tables/test_vector_dbs.py | 274 ++++++++++++++++++ 6 files changed, 450 insertions(+), 77 deletions(-) create mode 100644 tests/unit/distribution/routing_tables/__init__.py create mode 100644 tests/unit/distribution/routing_tables/test_vector_dbs.py diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py index cd56ada7b..a1dd66060 100644 --- a/llama_stack/distribution/routers/vector_io.py +++ b/llama_stack/distribution/routers/vector_io.py @@ -214,9 +214,7 @@ class VectorIORouter(VectorIO): vector_store_id: str, ) -> VectorStoreObject: logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_retrieve_vector_store(vector_store_id) + return await self.routing_table.openai_retrieve_vector_store(vector_store_id) async def openai_update_vector_store( self, @@ -226,9 +224,7 @@ class VectorIORouter(VectorIO): metadata: dict[str, Any] | None = None, ) -> VectorStoreObject: logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_update_vector_store( + return await self.routing_table.openai_update_vector_store( vector_store_id=vector_store_id, name=name, expires_after=expires_after, @@ -240,12 +236,7 @@ class VectorIORouter(VectorIO): vector_store_id: str, ) -> VectorStoreDeleteResponse: logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - result = await provider.openai_delete_vector_store(vector_store_id) - # drop from registry - await self.routing_table.unregister_vector_db(vector_store_id) - return result + return await self.routing_table.openai_delete_vector_store(vector_store_id) async def openai_search_vector_store( self, @@ -258,9 +249,7 @@ class VectorIORouter(VectorIO): search_mode: str | None = "vector", ) -> VectorStoreSearchResponsePage: logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_search_vector_store( + return await self.routing_table.openai_search_vector_store( vector_store_id=vector_store_id, query=query, filters=filters, @@ -278,9 +267,7 @@ class VectorIORouter(VectorIO): chunking_strategy: VectorStoreChunkingStrategy | None = None, ) -> VectorStoreFileObject: logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_attach_file_to_vector_store( + return await self.routing_table.openai_attach_file_to_vector_store( vector_store_id=vector_store_id, file_id=file_id, attributes=attributes, @@ -297,9 +284,7 @@ class VectorIORouter(VectorIO): filter: VectorStoreFileStatus | None = None, ) -> list[VectorStoreFileObject]: logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_list_files_in_vector_store( + return await self.routing_table.openai_list_files_in_vector_store( vector_store_id=vector_store_id, limit=limit, order=order, @@ -314,9 +299,7 @@ class VectorIORouter(VectorIO): file_id: str, ) -> VectorStoreFileObject: logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_retrieve_vector_store_file( + return await self.routing_table.openai_retrieve_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, ) @@ -327,9 +310,7 @@ class VectorIORouter(VectorIO): file_id: str, ) -> VectorStoreFileContentsResponse: logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_retrieve_vector_store_file_contents( + return await self.routing_table.openai_retrieve_vector_store_file_contents( vector_store_id=vector_store_id, file_id=file_id, ) @@ -341,9 +322,7 @@ class VectorIORouter(VectorIO): attributes: dict[str, Any], ) -> VectorStoreFileObject: logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_update_vector_store_file( + return await self.routing_table.openai_update_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, attributes=attributes, @@ -355,9 +334,7 @@ class VectorIORouter(VectorIO): file_id: str, ) -> VectorStoreFileDeleteResponse: logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_delete_vector_store_file( + return await self.routing_table.openai_delete_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, ) diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py index 7f7de32fe..bbe0113e9 100644 --- a/llama_stack/distribution/routing_tables/common.py +++ b/llama_stack/distribution/routing_tables/common.py @@ -9,6 +9,7 @@ from typing import Any from llama_stack.apis.resource import ResourceType from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed +from llama_stack.distribution.access_control.datatypes import Action from llama_stack.distribution.datatypes import ( AccessRule, RoutableObject, @@ -209,6 +210,20 @@ class CommonRoutingTableImpl(RoutingTable): await self.dist_registry.register(obj) return obj + async def assert_action_allowed( + self, + action: Action, + type: str, + identifier: str, + ) -> None: + """Fetch a registered object by type/identifier and enforce the given action permission.""" + obj = await self.get_object_by_identifier(type, identifier) + if obj is None: + raise ValueError(f"{type.capitalize()} '{identifier}' not found") + user = get_authenticated_user() + if not is_action_allowed(self.policy, action, obj, user): + raise AccessDeniedError(action, obj, user) + async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]: objs = await self.dist_registry.get_all() filtered_objs = [obj for obj in objs if obj.type == type] diff --git a/llama_stack/distribution/routing_tables/vector_dbs.py b/llama_stack/distribution/routing_tables/vector_dbs.py index f861102c8..b4e60c625 100644 --- a/llama_stack/distribution/routing_tables/vector_dbs.py +++ b/llama_stack/distribution/routing_tables/vector_dbs.py @@ -4,11 +4,24 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Any + from pydantic import TypeAdapter from llama_stack.apis.models import ModelType from llama_stack.apis.resource import ResourceType from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs +from llama_stack.apis.vector_io.vector_io import ( + SearchRankingOptions, + VectorStoreChunkingStrategy, + VectorStoreDeleteResponse, + VectorStoreFileContentsResponse, + VectorStoreFileDeleteResponse, + VectorStoreFileObject, + VectorStoreFileStatus, + VectorStoreObject, + VectorStoreSearchResponsePage, +) from llama_stack.distribution.datatypes import ( VectorDBWithOwner, ) @@ -74,3 +87,135 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): if existing_vector_db is None: raise ValueError(f"Vector DB {vector_db_id} not found") await self.unregister_object(existing_vector_db) + + async def openai_retrieve_vector_store( + self, + vector_store_id: str, + ) -> VectorStoreObject: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store(vector_store_id) + + async def openai_update_vector_store( + self, + vector_store_id: str, + name: str | None = None, + expires_after: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + ) -> VectorStoreObject: + await self.assert_action_allowed("update", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_update_vector_store( + vector_store_id=vector_store_id, + name=name, + expires_after=expires_after, + metadata=metadata, + ) + + async def openai_delete_vector_store( + self, + vector_store_id: str, + ) -> VectorStoreDeleteResponse: + await self.assert_action_allowed("delete", "vector_db", vector_store_id) + result = await self.get_provider_impl(vector_store_id).openai_delete_vector_store(vector_store_id) + await self.unregister_vector_db(vector_store_id) + return result + + async def openai_search_vector_store( + self, + vector_store_id: str, + query: str | list[str], + filters: dict[str, Any] | None = None, + max_num_results: int | None = 10, + ranking_options: SearchRankingOptions | None = None, + rewrite_query: bool | None = False, + search_mode: str | None = "vector", + ) -> VectorStoreSearchResponsePage: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_search_vector_store( + vector_store_id=vector_store_id, + query=query, + filters=filters, + max_num_results=max_num_results, + ranking_options=ranking_options, + rewrite_query=rewrite_query, + search_mode=search_mode, + ) + + async def openai_attach_file_to_vector_store( + self, + vector_store_id: str, + file_id: str, + attributes: dict[str, Any] | None = None, + chunking_strategy: VectorStoreChunkingStrategy | None = None, + ) -> VectorStoreFileObject: + await self.assert_action_allowed("update", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_attach_file_to_vector_store( + vector_store_id=vector_store_id, + file_id=file_id, + attributes=attributes, + chunking_strategy=chunking_strategy, + ) + + async def openai_list_files_in_vector_store( + self, + vector_store_id: str, + limit: int | None = 20, + order: str | None = "desc", + after: str | None = None, + before: str | None = None, + filter: VectorStoreFileStatus | None = None, + ) -> list[VectorStoreFileObject]: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_list_files_in_vector_store( + vector_store_id=vector_store_id, + limit=limit, + order=order, + after=after, + before=before, + filter=filter, + ) + + async def openai_retrieve_vector_store_file( + self, + vector_store_id: str, + file_id: str, + ) -> VectorStoreFileObject: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file( + vector_store_id=vector_store_id, + file_id=file_id, + ) + + async def openai_retrieve_vector_store_file_contents( + self, + vector_store_id: str, + file_id: str, + ) -> VectorStoreFileContentsResponse: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file_contents( + vector_store_id=vector_store_id, + file_id=file_id, + ) + + async def openai_update_vector_store_file( + self, + vector_store_id: str, + file_id: str, + attributes: dict[str, Any], + ) -> VectorStoreFileObject: + await self.assert_action_allowed("update", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_update_vector_store_file( + vector_store_id=vector_store_id, + file_id=file_id, + attributes=attributes, + ) + + async def openai_delete_vector_store_file( + self, + vector_store_id: str, + file_id: str, + ) -> VectorStoreFileDeleteResponse: + await self.assert_action_allowed("delete", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_delete_vector_store_file( + vector_store_id=vector_store_id, + file_id=file_id, + ) diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 3ba042bd9..30f795d33 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -11,17 +11,15 @@ from unittest.mock import AsyncMock from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datatypes import Api -from llama_stack.apis.models import Model, ModelType +from llama_stack.apis.models import Model from llama_stack.apis.shields.shields import Shield from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter -from llama_stack.apis.vector_dbs.vector_dbs import VectorDB from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable from llama_stack.distribution.routing_tables.models import ModelsRoutingTable from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable -from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable class Impl: @@ -54,17 +52,6 @@ class SafetyImpl(Impl): return shield -class VectorDBImpl(Impl): - def __init__(self): - super().__init__(Api.vector_io) - - async def register_vector_db(self, vector_db: VectorDB): - return vector_db - - async def unregister_vector_db(self, vector_db_id: str): - return vector_db_id - - class DatasetsImpl(Impl): def __init__(self): super().__init__(Api.datasetio) @@ -173,36 +160,6 @@ async def test_shields_routing_table(cached_disk_dist_registry): assert "test-shield-2" in shield_ids -async def test_vectordbs_routing_table(cached_disk_dist_registry): - table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) - await table.initialize() - - m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) - await m_table.initialize() - await m_table.register_model( - model_id="test-model", - provider_id="test_provider", - metadata={"embedding_dimension": 128}, - model_type=ModelType.embedding, - ) - - # Register multiple vector databases and verify listing - await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model") - await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model") - vector_dbs = await table.list_vector_dbs() - - assert len(vector_dbs.data) == 2 - vector_db_ids = {v.identifier for v in vector_dbs.data} - assert "test-vectordb" in vector_db_ids - assert "test-vectordb-2" in vector_db_ids - - await table.unregister_vector_db(vector_db_id="test-vectordb") - await table.unregister_vector_db(vector_db_id="test-vectordb-2") - - vector_dbs = await table.list_vector_dbs() - assert len(vector_dbs.data) == 0 - - async def test_datasets_routing_table(cached_disk_dist_registry): table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {}) await table.initialize() diff --git a/tests/unit/distribution/routing_tables/__init__.py b/tests/unit/distribution/routing_tables/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/unit/distribution/routing_tables/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/unit/distribution/routing_tables/test_vector_dbs.py b/tests/unit/distribution/routing_tables/test_vector_dbs.py new file mode 100644 index 000000000..28887e1cf --- /dev/null +++ b/tests/unit/distribution/routing_tables/test_vector_dbs.py @@ -0,0 +1,274 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Unit tests for the routing tables vector_dbs + +import time +from unittest.mock import AsyncMock + +import pytest + +from llama_stack.apis.datatypes import Api +from llama_stack.apis.models import ModelType +from llama_stack.apis.vector_dbs.vector_dbs import VectorDB +from llama_stack.apis.vector_io.vector_io import ( + VectorStoreContent, + VectorStoreDeleteResponse, + VectorStoreFileContentsResponse, + VectorStoreFileCounts, + VectorStoreFileDeleteResponse, + VectorStoreFileObject, + VectorStoreObject, + VectorStoreSearchResponsePage, +) +from llama_stack.distribution.access_control.datatypes import AccessRule, Scope +from llama_stack.distribution.datatypes import User +from llama_stack.distribution.request_headers import request_provider_data_context +from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable +from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceImpl, ModelsRoutingTable + + +class VectorDBImpl(Impl): + def __init__(self): + super().__init__(Api.vector_io) + + async def register_vector_db(self, vector_db: VectorDB): + return vector_db + + async def unregister_vector_db(self, vector_db_id: str): + return vector_db_id + + async def openai_retrieve_vector_store(self, vector_store_id): + return VectorStoreObject( + id=vector_store_id, + name="Test Store", + created_at=int(time.time()), + file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0), + ) + + async def openai_update_vector_store(self, vector_store_id, **kwargs): + return VectorStoreObject( + id=vector_store_id, + name="Updated Store", + created_at=int(time.time()), + file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0), + ) + + async def openai_delete_vector_store(self, vector_store_id): + return VectorStoreDeleteResponse(id=vector_store_id, object="vector_store.deleted", deleted=True) + + async def openai_search_vector_store(self, vector_store_id, query, **kwargs): + return VectorStoreSearchResponsePage( + object="vector_store.search_results.page", search_query="query", data=[], has_more=False, next_page=None + ) + + async def openai_attach_file_to_vector_store(self, vector_store_id, file_id, **kwargs): + return VectorStoreFileObject( + id=file_id, + status="completed", + chunking_strategy={"type": "auto"}, + created_at=int(time.time()), + vector_store_id=vector_store_id, + ) + + async def openai_list_files_in_vector_store(self, vector_store_id, **kwargs): + return [ + VectorStoreFileObject( + id="1", + status="completed", + chunking_strategy={"type": "auto"}, + created_at=int(time.time()), + vector_store_id=vector_store_id, + ) + ] + + async def openai_retrieve_vector_store_file(self, vector_store_id, file_id): + return VectorStoreFileObject( + id=file_id, + status="completed", + chunking_strategy={"type": "auto"}, + created_at=int(time.time()), + vector_store_id=vector_store_id, + ) + + async def openai_retrieve_vector_store_file_contents(self, vector_store_id, file_id): + return VectorStoreFileContentsResponse( + file_id=file_id, + filename="Sample File name", + attributes={"key": "value"}, + content=[VectorStoreContent(type="text", text="Sample content")], + ) + + async def openai_update_vector_store_file(self, vector_store_id, file_id, **kwargs): + return VectorStoreFileObject( + id=file_id, + status="completed", + chunking_strategy={"type": "auto"}, + created_at=int(time.time()), + vector_store_id=vector_store_id, + ) + + async def openai_delete_vector_store_file(self, vector_store_id, file_id): + return VectorStoreFileDeleteResponse(id=file_id, deleted=True) + + +async def test_vectordbs_routing_table(cached_disk_dist_registry): + table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + # Register multiple vector databases and verify listing + await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model") + await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model") + vector_dbs = await table.list_vector_dbs() + + assert len(vector_dbs.data) == 2 + vector_db_ids = {v.identifier for v in vector_dbs.data} + assert "test-vectordb" in vector_db_ids + assert "test-vectordb-2" in vector_db_ids + + await table.unregister_vector_db(vector_db_id="test-vectordb") + await table.unregister_vector_db(vector_db_id="test-vectordb-2") + + vector_dbs = await table.list_vector_dbs() + assert len(vector_dbs.data) == 0 + + +async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry): + impl = VectorDBImpl() + impl.openai_retrieve_vector_store = AsyncMock(return_value="OK") + table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=[]) + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[]) + authorized_table = "vs1" + authorized_team = "team1" + unauthorized_team = "team2" + + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + authorized_user = User(principal="alice", attributes={"roles": [authorized_team]}) + with request_provider_data_context({}, authorized_user): + _ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model") + + # Authorized reader + with request_provider_data_context({}, authorized_user): + res = await table.openai_retrieve_vector_store(authorized_table) + assert res == "OK" + + # Authorized updater + impl.openai_update_vector_store_file = AsyncMock(return_value="UPDATED") + with request_provider_data_context({}, authorized_user): + res = await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"}) + assert res == "UPDATED" + + # Unauthorized reader + unauthorized_user = User(principal="eve", attributes={"roles": [unauthorized_team]}) + with request_provider_data_context({}, unauthorized_user): + with pytest.raises(ValueError): + await table.openai_retrieve_vector_store(authorized_table) + + # Unauthorized updater + with request_provider_data_context({}, unauthorized_user): + with pytest.raises(ValueError): + await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"}) + + # Authorized deleter + impl.openai_delete_vector_store_file = AsyncMock(return_value="DELETED") + with request_provider_data_context({}, authorized_user): + res = await table.openai_delete_vector_store_file(authorized_table, file_id="file1") + assert res == "DELETED" + + # Unauthorized deleter + with request_provider_data_context({}, unauthorized_user): + with pytest.raises(ValueError): + await table.openai_delete_vector_store_file(authorized_table, file_id="file1") + + +async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_registry): + impl = VectorDBImpl() + + policy = [ + AccessRule(permit=Scope(actions=["create", "read", "update", "delete"]), when="user with admin in roles"), + AccessRule(permit=Scope(actions=["read"]), when="user with reader in roles"), + ] + + table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=policy) + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[]) + + vector_db_id = "vs1" + file_id = "file-1" + + admin_user = User(principal="admin", attributes={"roles": ["admin"]}) + read_only_user = User(principal="reader", attributes={"roles": ["reader"]}) + no_access_user = User(principal="outsider", attributes={"roles": ["no_access"]}) + + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + with request_provider_data_context({}, admin_user): + await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model") + + read_methods = [ + (table.openai_retrieve_vector_store, (vector_db_id,), {}), + (table.openai_search_vector_store, (vector_db_id, "query"), {}), + (table.openai_list_files_in_vector_store, (vector_db_id,), {}), + (table.openai_retrieve_vector_store_file, (vector_db_id, file_id), {}), + (table.openai_retrieve_vector_store_file_contents, (vector_db_id, file_id), {}), + ] + update_methods = [ + (table.openai_update_vector_store, (vector_db_id,), {"name": "Updated DB"}), + (table.openai_attach_file_to_vector_store, (vector_db_id, file_id), {}), + (table.openai_update_vector_store_file, (vector_db_id, file_id), {"attributes": {"key": "value"}}), + ] + delete_methods = [ + (table.openai_delete_vector_store_file, (vector_db_id, file_id), {}), + (table.openai_delete_vector_store, (vector_db_id,), {}), + ] + + for user in [admin_user, read_only_user]: + with request_provider_data_context({}, user): + for method, args, kwargs in read_methods: + result = await method(*args, **kwargs) + assert result is not None, f"Read operation failed with user {user.principal}" + + with request_provider_data_context({}, no_access_user): + for method, args, kwargs in read_methods: + with pytest.raises(ValueError): + await method(*args, **kwargs) + + with request_provider_data_context({}, admin_user): + for method, args, kwargs in update_methods: + result = await method(*args, **kwargs) + assert result is not None, "Update operation failed with admin user" + + with request_provider_data_context({}, admin_user): + for method, args, kwargs in delete_methods: + result = await method(*args, **kwargs) + assert result is not None, "Delete operation failed with admin user" + + for user in [read_only_user, no_access_user]: + with request_provider_data_context({}, user): + for method, args, kwargs in delete_methods: + with pytest.raises(ValueError): + await method(*args, **kwargs) From 2bc96613f918316a5df85925b0c7872127947cae Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Mon, 21 Jul 2025 22:53:32 -0400 Subject: [PATCH 09/10] chore: Adding demo script and importing it into the docs (#2848) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? This PR adds the quickstart as a file to the docs so that it can be more easily maintained and run, as mentioned in https://github.com/meta-llama/llama-stack/pull/2800. ## Test Plan I could add this as a test in the CI but I wasn't sure if we wanted to add additional jobs there. 😅 Signed-off-by: Francisco Javier Arceo --- docs/source/getting_started/demo_script.py | 62 ++++++++++++++++++++++ docs/source/getting_started/quickstart.md | 59 +------------------- 2 files changed, 64 insertions(+), 57 deletions(-) create mode 100644 docs/source/getting_started/demo_script.py diff --git a/docs/source/getting_started/demo_script.py b/docs/source/getting_started/demo_script.py new file mode 100644 index 000000000..298fd9899 --- /dev/null +++ b/docs/source/getting_started/demo_script.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient + +vector_db_id = "my_demo_vector_db" +client = LlamaStackClient(base_url="http://localhost:8321") + +models = client.models.list() + +# Select the first LLM and first embedding models +model_id = next(m for m in models if m.model_type == "llm").identifier +embedding_model_id = ( + em := next(m for m in models if m.model_type == "embedding") +).identifier +embedding_dimension = em.metadata["embedding_dimension"] + +_ = client.vector_dbs.register( + vector_db_id=vector_db_id, + embedding_model=embedding_model_id, + embedding_dimension=embedding_dimension, + provider_id="faiss", +) +source = "https://www.paulgraham.com/greatwork.html" +print("rag_tool> Ingesting document:", source) +document = RAGDocument( + document_id="document_1", + content=source, + mime_type="text/html", + metadata={}, +) +client.tool_runtime.rag_tool.insert( + documents=[document], + vector_db_id=vector_db_id, + chunk_size_in_tokens=50, +) +agent = Agent( + client, + model=model_id, + instructions="You are a helpful assistant", + tools=[ + { + "name": "builtin::rag/knowledge_search", + "args": {"vector_db_ids": [vector_db_id]}, + } + ], +) + +prompt = "How do you do great work?" +print("prompt>", prompt) + +response = agent.create_turn( + messages=[{"role": "user", "content": prompt}], + session_id=agent.create_session("rag_session"), + stream=True, +) + +for log in AgentEventLogger().log(response): + log.print() diff --git a/docs/source/getting_started/quickstart.md b/docs/source/getting_started/quickstart.md index 59791643d..5549f412c 100644 --- a/docs/source/getting_started/quickstart.md +++ b/docs/source/getting_started/quickstart.md @@ -24,63 +24,8 @@ ENABLE_OLLAMA=ollama OLLAMA_INFERENCE_MODEL=llama3.2:3b uv run --with llama-stac #### Step 3: Run the demo Now open up a new terminal and copy the following script into a file named `demo_script.py`. -```python -from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient - -vector_db_id = "my_demo_vector_db" -client = LlamaStackClient(base_url="http://localhost:8321") - -models = client.models.list() - -# Select the first LLM and first embedding models -model_id = next(m for m in models if m.model_type == "llm").identifier -embedding_model_id = ( - em := next(m for m in models if m.model_type == "embedding") -).identifier -embedding_dimension = em.metadata["embedding_dimension"] - -_ = client.vector_dbs.register( - vector_db_id=vector_db_id, - embedding_model=embedding_model_id, - embedding_dimension=embedding_dimension, - provider_id="faiss", -) -source = "https://www.paulgraham.com/greatwork.html" -print("rag_tool> Ingesting document:", source) -document = RAGDocument( - document_id="document_1", - content=source, - mime_type="text/html", - metadata={}, -) -client.tool_runtime.rag_tool.insert( - documents=[document], - vector_db_id=vector_db_id, - chunk_size_in_tokens=50, -) -agent = Agent( - client, - model=model_id, - instructions="You are a helpful assistant", - tools=[ - { - "name": "builtin::rag/knowledge_search", - "args": {"vector_db_ids": [vector_db_id]}, - } - ], -) - -prompt = "How do you do great work?" -print("prompt>", prompt) - -response = agent.create_turn( - messages=[{"role": "user", "content": prompt}], - session_id=agent.create_session("rag_session"), - stream=True, -) - -for log in AgentEventLogger().log(response): - log.print() +```{literalinclude} ./demo_script.py +:language: python ``` We will use `uv` to run the script ``` From b5a6ecc331f14046a77a00f86b0dbfa6376ee0a5 Mon Sep 17 00:00:00 2001 From: Jeremy Bonghwan Choi Date: Tue, 22 Jul 2025 15:10:35 +1000 Subject: [PATCH 10/10] docs: minor fix of the pgvector provider spec description (#2847) # What does this PR do? minor update of the pgvector doc, changing 'faiss' to 'pgvector' ## Test Plan --- docs/source/providers/vector_io/remote_pgvector.md | 2 +- llama_stack/providers/registry/vector_io.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/providers/vector_io/remote_pgvector.md b/docs/source/providers/vector_io/remote_pgvector.md index 3e7d6e776..74f588a13 100644 --- a/docs/source/providers/vector_io/remote_pgvector.md +++ b/docs/source/providers/vector_io/remote_pgvector.md @@ -17,7 +17,7 @@ That means you'll get fast and efficient vector retrieval. To use PGVector in your Llama Stack project, follow these steps: 1. Install the necessary dependencies. -2. Configure your Llama Stack project to use Faiss. +2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector). 3. Start storing and querying vectors. ## Installation diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index c13e65bbc..e391341b4 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -395,7 +395,7 @@ That means you'll get fast and efficient vector retrieval. To use PGVector in your Llama Stack project, follow these steps: 1. Install the necessary dependencies. -2. Configure your Llama Stack project to use Faiss. +2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector). 3. Start storing and querying vectors. ## Installation