From 9e6860b9cf27afaf48115df48c1670d6c33c40ba Mon Sep 17 00:00:00 2001 From: IAN MILLER <75687988+r3v5@users.noreply.github.com> Date: Mon, 21 Jul 2025 17:14:34 +0100 Subject: [PATCH 1/6] fix: remove @pytest.mark.asyncio from test_get_raw_document_text.py (#2840) # What does this PR do? The pre-commit workflow was failing in the main branch and removing `@pytest.mark.asyncio `from `test_get_raw_document_text.py` fixed that. ## Test Plan --- .../providers/agent/test_get_raw_document_text.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/unit/providers/agent/test_get_raw_document_text.py b/tests/unit/providers/agent/test_get_raw_document_text.py index ddc886293..eb481c0d8 100644 --- a/tests/unit/providers/agent/test_get_raw_document_text.py +++ b/tests/unit/providers/agent/test_get_raw_document_text.py @@ -14,7 +14,6 @@ from llama_stack.apis.common.content_types import URL, TextContentItem from llama_stack.providers.inline.agents.meta_reference.agent_instance import get_raw_document_text -@pytest.mark.asyncio async def test_get_raw_document_text_supports_text_mime_types(): """Test that the function accepts text/* mime types.""" document = Document(content="Sample text content", mime_type="text/plain") @@ -23,7 +22,6 @@ async def test_get_raw_document_text_supports_text_mime_types(): assert result == "Sample text content" -@pytest.mark.asyncio async def test_get_raw_document_text_supports_yaml_mime_type(): """Test that the function accepts application/yaml mime type.""" yaml_content = """ @@ -40,7 +38,6 @@ async def test_get_raw_document_text_supports_yaml_mime_type(): assert result == yaml_content -@pytest.mark.asyncio async def test_get_raw_document_text_supports_deprecated_text_yaml_with_warning(): """Test that the function accepts text/yaml but emits a deprecation warning.""" yaml_content = """ @@ -68,7 +65,6 @@ async def test_get_raw_document_text_supports_deprecated_text_yaml_with_warning( assert "deprecated" in str(w[0].message).lower() -@pytest.mark.asyncio async def test_get_raw_document_text_deprecated_text_yaml_with_url(): """Test that text/yaml works with URL content and emits warning.""" yaml_content = "name: test\nversion: 1.0" @@ -92,7 +88,6 @@ async def test_get_raw_document_text_deprecated_text_yaml_with_url(): assert "text/yaml" in str(w[0].message) -@pytest.mark.asyncio async def test_get_raw_document_text_deprecated_text_yaml_with_text_content_item(): """Test that text/yaml works with TextContentItem and emits warning.""" yaml_content = "key: value\nlist:\n - item1\n - item2" @@ -112,7 +107,6 @@ async def test_get_raw_document_text_deprecated_text_yaml_with_text_content_item assert "text/yaml" in str(w[0].message) -@pytest.mark.asyncio async def test_get_raw_document_text_rejects_unsupported_mime_types(): """Test that the function rejects unsupported mime types.""" document = Document( @@ -124,7 +118,6 @@ async def test_get_raw_document_text_rejects_unsupported_mime_types(): await get_raw_document_text(document) -@pytest.mark.asyncio async def test_get_raw_document_text_with_url_content(): """Test that the function handles URL content correctly.""" mock_response = AsyncMock() @@ -140,7 +133,6 @@ async def test_get_raw_document_text_with_url_content(): mock_load.assert_called_once_with("https://example.com/test.txt") -@pytest.mark.asyncio async def test_get_raw_document_text_with_yaml_url(): """Test that the function handles YAML URLs correctly.""" yaml_content = "name: test\nversion: 1.0" @@ -155,7 +147,6 @@ async def test_get_raw_document_text_with_yaml_url(): mock_load.assert_called_once_with("https://example.com/config.yaml") -@pytest.mark.asyncio async def test_get_raw_document_text_with_text_content_item(): """Test that the function handles TextContentItem correctly.""" document = Document(content=TextContentItem(text="Text content item"), mime_type="text/plain") @@ -164,7 +155,6 @@ async def test_get_raw_document_text_with_text_content_item(): assert result == "Text content item" -@pytest.mark.asyncio async def test_get_raw_document_text_with_yaml_text_content_item(): """Test that the function handles YAML TextContentItem correctly.""" yaml_content = "key: value\nlist:\n - item1\n - item2" @@ -175,7 +165,6 @@ async def test_get_raw_document_text_with_yaml_text_content_item(): assert result == yaml_content -@pytest.mark.asyncio async def test_get_raw_document_text_rejects_unexpected_content_type(): """Test that the function rejects unexpected document content types.""" # Create a mock document that bypasses Pydantic validation From d0208df286987dc78cd2f64d4326ae247e7b39d6 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Mon, 21 Jul 2025 10:01:40 -0700 Subject: [PATCH 2/6] test: skip flaky telemetry tests (#2814) # What does this PR do? example error: https://github.com/meta-llama/llama-stack/actions/runs/16368394907/job/46250869773 ## Test Plan --- tests/integration/telemetry/test_telemetry.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/integration/telemetry/test_telemetry.py b/tests/integration/telemetry/test_telemetry.py index 9df03da70..675dc780a 100644 --- a/tests/integration/telemetry/test_telemetry.py +++ b/tests/integration/telemetry/test_telemetry.py @@ -50,6 +50,7 @@ def setup_telemetry_data(llama_stack_client, text_model_id): yield +@pytest.mark.skip(reason="Skipping telemetry tests for now") def test_query_traces_basic(llama_stack_client): """Test basic trace querying functionality with proper data validation.""" all_traces = llama_stack_client.telemetry.query_traces(limit=5) @@ -105,6 +106,7 @@ def test_query_traces_basic(llama_stack_client): assert hasattr(trace, "root_span_id") and trace.root_span_id, "Each trace should have non-empty root_span_id" +@pytest.mark.skip(reason="Skipping telemetry tests for now") def test_query_spans_basic(llama_stack_client): """Test basic span querying functionality with proper validation.""" spans = llama_stack_client.telemetry.query_spans(attribute_filters=[], attributes_to_return=[]) @@ -153,6 +155,7 @@ def test_query_spans_basic(llama_stack_client): assert hasattr(span, attr) and getattr(span, attr), f"All spans should have non-empty {attr}" +@pytest.mark.skip(reason="Skipping telemetry tests for now") def test_telemetry_pagination(llama_stack_client): """Test pagination in telemetry queries.""" # Get total count of traces From 019ddda13845e816aaf728924fbe4232565e07cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Mon, 21 Jul 2025 20:35:15 +0200 Subject: [PATCH 3/6] fix: graceful SIGINT on server (#2831) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? After https://github.com/meta-llama/llama-stack/pull/2818, SIGINT will print a stack trace. This is because uvicorn re-raises SIGINT and it gets converted by Python internal signal handler (default handles SIGINT) to KeyboardInterrupt exception. We know simply catch the exception to get a clean exit, this is not changing the behavior on SIGINT. ## Test Plan Run the server, hit Ctrl+C or `kill -2 ` and expect a clean exit with no stack trace. Signed-off-by: Sébastien Han --- llama_stack/distribution/server/server.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index e7e9e5e88..935688946 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -597,7 +597,23 @@ def main(args: argparse.Namespace | None = None): uvicorn_config.update(ssl_config) # Run uvicorn in the existing event loop to preserve background tasks - loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve()) + # We need to catch KeyboardInterrupt because uvicorn's signal handling + # re-raises SIGINT signals using signal.raise_signal(), which Python + # converts to KeyboardInterrupt. Without this catch, we'd get a confusing + # stack trace when using Ctrl+C or kill -2 (SIGINT). + # SIGTERM (kill -15) works fine without this because Python doesn't + # have a default handler for it. + # + # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own + # signal handling but this is quite intrusive and not worth the effort. + try: + loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve()) + except (KeyboardInterrupt, SystemExit): + logger.info("Received interrupt signal, shutting down gracefully...") + finally: + if not loop.is_closed(): + logger.debug("Closing event loop") + loop.close() def _log_run_config(run_config: StackRunConfig): From 9a03526672dd34e65940015d477db9ec6f4532ba Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Mon, 21 Jul 2025 15:50:39 -0400 Subject: [PATCH 4/6] fix: uvicorn respect log_config (#2842) --- llama_stack/distribution/server/server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 935688946..ede65e8d6 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -592,6 +592,7 @@ def main(args: argparse.Namespace | None = None): "port": port, "lifespan": "on", "log_level": logger.getEffectiveLevel(), + "log_config": logger_config, } if ssl_config: uvicorn_config.update(ssl_config) From 0d7a90b8bc50f31acc12c244456a59f7bd0559e5 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Mon, 21 Jul 2025 13:19:27 -0700 Subject: [PATCH 5/6] chore: merge --config and --template in server.py (#2716) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? Part of #2696 ## Test Plan Run `llama stack run starter` Error: ``` myenv ❯ llama stack run starters WARNING 2025-07-10 12:12:43,052 llama_stack.cli.stack.run:82 server: Conda detected. Using conda environment myenv for the run. usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--env KEY=VALUE] [--image-type {conda,venv}] [--enable-ui] [config | template] llama stack run: error: Could not resolve config or template 'starters'. Tried the following locations: 1. As file path: /Users/erichuang/projects/llama-stack-git/starters 2. As template: /Users/erichuang/projects/llama-stack-git/llama_stack/templates/starters/run.yaml 3. As built distribution: (/Users/erichuang/.llama/distributions/llamastack-starters/starters-run.yaml, /Users/erichuang/.llama/distributions/starters/starters-run.yaml) Available templates: dell, test-env, vllm-gpu, test-template, cerebras, openai-api-verification, sambanova, passthrough, direct-config, together, openai, fireworks, meta-reference-gpu, __pycache__, dev, ollama, watsonx, remote-vllm, llama_api, groq, dummy, oracle, nvidia, ci-tests, postgres-demo, test-stack, bedrock, starter, hf-serverless, hf-endpoint, tgi, open-benchmark, verification Did you mean one of these templates? - starter - together - postgres-demo ``` --- llama_stack/cli/stack/_build.py | 4 +- llama_stack/cli/stack/run.py | 52 ++------ llama_stack/cli/utils.py | 31 +++++ llama_stack/distribution/server/server.py | 36 +---- llama_stack/distribution/start_stack.sh | 2 +- .../distribution/utils/config_resolution.py | 125 ++++++++++++++++++ llama_stack/distribution/utils/exec.py | 2 +- 7 files changed, 176 insertions(+), 76 deletions(-) create mode 100644 llama_stack/cli/utils.py create mode 100644 llama_stack/distribution/utils/config_resolution.py diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index b573b2edc..3f94b1e2c 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -276,8 +276,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None: config = parse_and_maybe_upgrade_config(config_dict) if config.external_providers_dir and not config.external_providers_dir.exists(): config.external_providers_dir.mkdir(exist_ok=True) - run_args = formulate_run_args(args.image_type, args.image_name, config, args.template) - run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config]) + run_args = formulate_run_args(args.image_type, args.image_name) + run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", str(run_config)]) run_command(run_args) diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index f4a119522..3cb2e213c 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -82,39 +82,6 @@ class StackRun(Subcommand): return ImageType.CONDA.value, args.image_name return args.image_type, args.image_name - def _resolve_config_and_template(self, args: argparse.Namespace) -> tuple[Path | None, str | None]: - """Resolve config file path and template name from args.config""" - from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR - - if not args.config: - return None, None - - config_file = Path(args.config) - has_yaml_suffix = args.config.endswith(".yaml") - template_name = None - - if not config_file.exists() and not has_yaml_suffix: - # check if this is a template - config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml" - if config_file.exists(): - template_name = args.config - - if not config_file.exists() and not has_yaml_suffix: - # check if it's a build config saved to ~/.llama dir - config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml") - - if not config_file.exists(): - self.parser.error( - f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file" - ) - - if not config_file.is_file(): - self.parser.error( - f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}" - ) - - return config_file, template_name - def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: import yaml @@ -125,8 +92,15 @@ class StackRun(Subcommand): self._start_ui_development_server(args.port) image_type, image_name = self._get_image_type_and_name(args) - # Resolve config file and template name first - config_file, template_name = self._resolve_config_and_template(args) + if args.config: + try: + from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template + + config_file = resolve_config_or_template(args.config, Mode.RUN) + except ValueError as e: + self.parser.error(str(e)) + else: + config_file = None # Check if config is required based on image type if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not config_file: @@ -164,18 +138,14 @@ class StackRun(Subcommand): if callable(getattr(args, arg)): continue if arg == "config": - if template_name: - server_args.template = str(template_name) - else: - # Set the config file path - server_args.config = str(config_file) + server_args.config = str(config_file) else: setattr(server_args, arg, getattr(args, arg)) # Run the server server_main(server_args) else: - run_args = formulate_run_args(image_type, image_name, config, template_name) + run_args = formulate_run_args(image_type, image_name) run_args.extend([str(args.port)]) diff --git a/llama_stack/cli/utils.py b/llama_stack/cli/utils.py new file mode 100644 index 000000000..433627cc0 --- /dev/null +++ b/llama_stack/cli/utils.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse + + +def add_config_template_args(parser: argparse.ArgumentParser): + """Add unified config/template arguments with backward compatibility.""" + group = parser.add_mutually_exclusive_group(required=True) + + group.add_argument( + "config", + nargs="?", + help="Configuration file path or template name", + ) + + # Backward compatibility arguments (deprecated) + group.add_argument( + "--config", + dest="config", + help="(DEPRECATED) Use positional argument [config] instead. Configuration file path", + ) + + group.add_argument( + "--template", + dest="config", + help="(DEPRECATED) Use positional argument [config] instead. Template name", + ) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index ede65e8d6..e58c28f2e 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -32,6 +32,7 @@ from openai import BadRequestError from pydantic import BaseModel, ValidationError from llama_stack.apis.common.responses import PaginatedResponse +from llama_stack.cli.utils import add_config_template_args from llama_stack.distribution.access_control.access_control import AccessDeniedError from llama_stack.distribution.datatypes import ( AuthenticationRequiredError, @@ -53,6 +54,7 @@ from llama_stack.distribution.stack import ( validate_env_pair, ) from llama_stack.distribution.utils.config import redact_sensitive_fields +from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template from llama_stack.distribution.utils.context import preserve_contexts_async_generator from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api @@ -377,20 +379,8 @@ class ClientVersionMiddleware: def main(args: argparse.Namespace | None = None): """Start the LlamaStack server.""" parser = argparse.ArgumentParser(description="Start the LlamaStack server.") - parser.add_argument( - "--yaml-config", - dest="config", - help="(Deprecated) Path to YAML configuration file - use --config instead", - ) - parser.add_argument( - "--config", - dest="config", - help="Path to YAML configuration file", - ) - parser.add_argument( - "--template", - help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)", - ) + + add_config_template_args(parser) parser.add_argument( "--port", type=int, @@ -409,20 +399,7 @@ def main(args: argparse.Namespace | None = None): if args is None: args = parser.parse_args() - log_line = "" - if hasattr(args, "config") and args.config: - # if the user provided a config file, use it, even if template was specified - config_file = Path(args.config) - if not config_file.exists(): - raise ValueError(f"Config file {config_file} does not exist") - log_line = f"Using config file: {config_file}" - elif hasattr(args, "template") and args.template: - config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml" - if not config_file.exists(): - raise ValueError(f"Template {args.template} does not exist") - log_line = f"Using template {args.template} config file: {config_file}" - else: - raise ValueError("Either --config or --template must be provided") + config_file = resolve_config_or_template(args.config, Mode.RUN) logger_config = None with open(config_file) as fp: @@ -442,9 +419,6 @@ def main(args: argparse.Namespace | None = None): config = replace_env_vars(config_contents) config = StackRunConfig(**cast_image_name_to_string(config)) - # now that the logger is initialized, print the line about which type of config we are using. - logger.info(log_line) - _log_run_config(run_config=config) app = FastAPI( diff --git a/llama_stack/distribution/start_stack.sh b/llama_stack/distribution/start_stack.sh index 85bfceec4..77a7dc92e 100755 --- a/llama_stack/distribution/start_stack.sh +++ b/llama_stack/distribution/start_stack.sh @@ -117,7 +117,7 @@ if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then set -x if [ -n "$yaml_config" ]; then - yaml_config_arg="--config $yaml_config" + yaml_config_arg="$yaml_config" else yaml_config_arg="" fi diff --git a/llama_stack/distribution/utils/config_resolution.py b/llama_stack/distribution/utils/config_resolution.py new file mode 100644 index 000000000..7e8de1242 --- /dev/null +++ b/llama_stack/distribution/utils/config_resolution.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import StrEnum +from pathlib import Path + +from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="config_resolution") + + +TEMPLATE_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "templates" + + +class Mode(StrEnum): + RUN = "run" + BUILD = "build" + + +def resolve_config_or_template( + config_or_template: str, + mode: Mode = Mode.RUN, +) -> Path: + """ + Resolve a config/template argument to a concrete config file path. + + Args: + config_or_template: User input (file path, template name, or built distribution) + mode: Mode resolving for ("run", "build", "server") + + Returns: + Path to the resolved config file + + Raises: + ValueError: If resolution fails + """ + + # Strategy 1: Try as file path first + config_path = Path(config_or_template) + if config_path.exists() and config_path.is_file(): + logger.info(f"Using file path: {config_path}") + return config_path.resolve() + + # Strategy 2: Try as template name (if no .yaml extension) + if not config_or_template.endswith(".yaml"): + template_config = _get_template_config_path(config_or_template, mode) + if template_config.exists(): + logger.info(f"Using template: {template_config}") + return template_config + + # Strategy 3: Try as built distribution name + distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml" + if distrib_config.exists(): + logger.info(f"Using built distribution: {distrib_config}") + return distrib_config + + distrib_config = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml" + if distrib_config.exists(): + logger.info(f"Using built distribution: {distrib_config}") + return distrib_config + + # Strategy 4: Failed - provide helpful error + raise ValueError(_format_resolution_error(config_or_template, mode)) + + +def _get_template_config_path(template_name: str, mode: Mode) -> Path: + """Get the config file path for a template.""" + return TEMPLATE_DIR / template_name / f"{mode}.yaml" + + +def _format_resolution_error(config_or_template: str, mode: Mode) -> str: + """Format a helpful error message for resolution failures.""" + from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR + + template_path = _get_template_config_path(config_or_template, mode) + distrib_path = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml" + distrib_path2 = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml" + + available_templates = _get_available_templates() + templates_str = ", ".join(available_templates) if available_templates else "none found" + + return f"""Could not resolve config or template '{config_or_template}'. + +Tried the following locations: + 1. As file path: {Path(config_or_template).resolve()} + 2. As template: {template_path} + 3. As built distribution: ({distrib_path}, {distrib_path2}) + +Available templates: {templates_str} + +Did you mean one of these templates? +{_format_template_suggestions(available_templates, config_or_template)} +""" + + +def _get_available_templates() -> list[str]: + """Get list of available template names.""" + if not TEMPLATE_DIR.exists() and not DISTRIBS_BASE_DIR.exists(): + return [] + + return list( + set( + [d.name for d in TEMPLATE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")] + + [d.name for d in DISTRIBS_BASE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")] + ) + ) + + +def _format_template_suggestions(templates: list[str], user_input: str) -> str: + """Format template suggestions for error messages, showing closest matches first.""" + if not templates: + return " (no templates found)" + + import difflib + + # Get up to 3 closest matches with similarity threshold of 0.3 (lower = more permissive) + close_matches = difflib.get_close_matches(user_input, templates, n=3, cutoff=0.3) + display_templates = close_matches if close_matches else templates[:3] + + suggestions = [f" - {t}" for t in display_templates] + return "\n".join(suggestions) diff --git a/llama_stack/distribution/utils/exec.py b/llama_stack/distribution/utils/exec.py index 2db01689f..c646ae821 100644 --- a/llama_stack/distribution/utils/exec.py +++ b/llama_stack/distribution/utils/exec.py @@ -21,7 +21,7 @@ from pathlib import Path from llama_stack.distribution.utils.image_types import LlamaStackImageType -def formulate_run_args(image_type, image_name, config, template_name) -> list: +def formulate_run_args(image_type: str, image_name: str) -> list[str]: env_name = "" if image_type == LlamaStackImageType.CONDA.value: From c8f274347d18acd9ed7e657ecf63d8acb0c5ba03 Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Mon, 21 Jul 2025 16:22:44 -0400 Subject: [PATCH 6/6] chore: Adding Access Control for OpenAI Vector Stores methods (#2772) # What does this PR do? Refactors the vector store routing logic by moving OpenAI-compatible vector store operations from the `VectorIORouter` to the `VectorDBsRoutingTable`. Closes https://github.com/meta-llama/llama-stack/issues/2761 ## Test Plan Added unit tests to cover new routing logic and ACL checks. --------- Signed-off-by: Francisco Javier Arceo --- llama_stack/distribution/routers/vector_io.py | 43 +-- .../distribution/routing_tables/common.py | 15 + .../distribution/routing_tables/vector_dbs.py | 145 +++++++++ .../routers/test_routing_tables.py | 45 +-- .../distribution/routing_tables/__init__.py | 5 + .../routing_tables/test_vector_dbs.py | 274 ++++++++++++++++++ 6 files changed, 450 insertions(+), 77 deletions(-) create mode 100644 tests/unit/distribution/routing_tables/__init__.py create mode 100644 tests/unit/distribution/routing_tables/test_vector_dbs.py diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py index cd56ada7b..a1dd66060 100644 --- a/llama_stack/distribution/routers/vector_io.py +++ b/llama_stack/distribution/routers/vector_io.py @@ -214,9 +214,7 @@ class VectorIORouter(VectorIO): vector_store_id: str, ) -> VectorStoreObject: logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_retrieve_vector_store(vector_store_id) + return await self.routing_table.openai_retrieve_vector_store(vector_store_id) async def openai_update_vector_store( self, @@ -226,9 +224,7 @@ class VectorIORouter(VectorIO): metadata: dict[str, Any] | None = None, ) -> VectorStoreObject: logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_update_vector_store( + return await self.routing_table.openai_update_vector_store( vector_store_id=vector_store_id, name=name, expires_after=expires_after, @@ -240,12 +236,7 @@ class VectorIORouter(VectorIO): vector_store_id: str, ) -> VectorStoreDeleteResponse: logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - result = await provider.openai_delete_vector_store(vector_store_id) - # drop from registry - await self.routing_table.unregister_vector_db(vector_store_id) - return result + return await self.routing_table.openai_delete_vector_store(vector_store_id) async def openai_search_vector_store( self, @@ -258,9 +249,7 @@ class VectorIORouter(VectorIO): search_mode: str | None = "vector", ) -> VectorStoreSearchResponsePage: logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_search_vector_store( + return await self.routing_table.openai_search_vector_store( vector_store_id=vector_store_id, query=query, filters=filters, @@ -278,9 +267,7 @@ class VectorIORouter(VectorIO): chunking_strategy: VectorStoreChunkingStrategy | None = None, ) -> VectorStoreFileObject: logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_attach_file_to_vector_store( + return await self.routing_table.openai_attach_file_to_vector_store( vector_store_id=vector_store_id, file_id=file_id, attributes=attributes, @@ -297,9 +284,7 @@ class VectorIORouter(VectorIO): filter: VectorStoreFileStatus | None = None, ) -> list[VectorStoreFileObject]: logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_list_files_in_vector_store( + return await self.routing_table.openai_list_files_in_vector_store( vector_store_id=vector_store_id, limit=limit, order=order, @@ -314,9 +299,7 @@ class VectorIORouter(VectorIO): file_id: str, ) -> VectorStoreFileObject: logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_retrieve_vector_store_file( + return await self.routing_table.openai_retrieve_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, ) @@ -327,9 +310,7 @@ class VectorIORouter(VectorIO): file_id: str, ) -> VectorStoreFileContentsResponse: logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_retrieve_vector_store_file_contents( + return await self.routing_table.openai_retrieve_vector_store_file_contents( vector_store_id=vector_store_id, file_id=file_id, ) @@ -341,9 +322,7 @@ class VectorIORouter(VectorIO): attributes: dict[str, Any], ) -> VectorStoreFileObject: logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_update_vector_store_file( + return await self.routing_table.openai_update_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, attributes=attributes, @@ -355,9 +334,7 @@ class VectorIORouter(VectorIO): file_id: str, ) -> VectorStoreFileDeleteResponse: logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_delete_vector_store_file( + return await self.routing_table.openai_delete_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, ) diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py index 7f7de32fe..bbe0113e9 100644 --- a/llama_stack/distribution/routing_tables/common.py +++ b/llama_stack/distribution/routing_tables/common.py @@ -9,6 +9,7 @@ from typing import Any from llama_stack.apis.resource import ResourceType from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed +from llama_stack.distribution.access_control.datatypes import Action from llama_stack.distribution.datatypes import ( AccessRule, RoutableObject, @@ -209,6 +210,20 @@ class CommonRoutingTableImpl(RoutingTable): await self.dist_registry.register(obj) return obj + async def assert_action_allowed( + self, + action: Action, + type: str, + identifier: str, + ) -> None: + """Fetch a registered object by type/identifier and enforce the given action permission.""" + obj = await self.get_object_by_identifier(type, identifier) + if obj is None: + raise ValueError(f"{type.capitalize()} '{identifier}' not found") + user = get_authenticated_user() + if not is_action_allowed(self.policy, action, obj, user): + raise AccessDeniedError(action, obj, user) + async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]: objs = await self.dist_registry.get_all() filtered_objs = [obj for obj in objs if obj.type == type] diff --git a/llama_stack/distribution/routing_tables/vector_dbs.py b/llama_stack/distribution/routing_tables/vector_dbs.py index f861102c8..b4e60c625 100644 --- a/llama_stack/distribution/routing_tables/vector_dbs.py +++ b/llama_stack/distribution/routing_tables/vector_dbs.py @@ -4,11 +4,24 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from typing import Any + from pydantic import TypeAdapter from llama_stack.apis.models import ModelType from llama_stack.apis.resource import ResourceType from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs +from llama_stack.apis.vector_io.vector_io import ( + SearchRankingOptions, + VectorStoreChunkingStrategy, + VectorStoreDeleteResponse, + VectorStoreFileContentsResponse, + VectorStoreFileDeleteResponse, + VectorStoreFileObject, + VectorStoreFileStatus, + VectorStoreObject, + VectorStoreSearchResponsePage, +) from llama_stack.distribution.datatypes import ( VectorDBWithOwner, ) @@ -74,3 +87,135 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): if existing_vector_db is None: raise ValueError(f"Vector DB {vector_db_id} not found") await self.unregister_object(existing_vector_db) + + async def openai_retrieve_vector_store( + self, + vector_store_id: str, + ) -> VectorStoreObject: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store(vector_store_id) + + async def openai_update_vector_store( + self, + vector_store_id: str, + name: str | None = None, + expires_after: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + ) -> VectorStoreObject: + await self.assert_action_allowed("update", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_update_vector_store( + vector_store_id=vector_store_id, + name=name, + expires_after=expires_after, + metadata=metadata, + ) + + async def openai_delete_vector_store( + self, + vector_store_id: str, + ) -> VectorStoreDeleteResponse: + await self.assert_action_allowed("delete", "vector_db", vector_store_id) + result = await self.get_provider_impl(vector_store_id).openai_delete_vector_store(vector_store_id) + await self.unregister_vector_db(vector_store_id) + return result + + async def openai_search_vector_store( + self, + vector_store_id: str, + query: str | list[str], + filters: dict[str, Any] | None = None, + max_num_results: int | None = 10, + ranking_options: SearchRankingOptions | None = None, + rewrite_query: bool | None = False, + search_mode: str | None = "vector", + ) -> VectorStoreSearchResponsePage: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_search_vector_store( + vector_store_id=vector_store_id, + query=query, + filters=filters, + max_num_results=max_num_results, + ranking_options=ranking_options, + rewrite_query=rewrite_query, + search_mode=search_mode, + ) + + async def openai_attach_file_to_vector_store( + self, + vector_store_id: str, + file_id: str, + attributes: dict[str, Any] | None = None, + chunking_strategy: VectorStoreChunkingStrategy | None = None, + ) -> VectorStoreFileObject: + await self.assert_action_allowed("update", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_attach_file_to_vector_store( + vector_store_id=vector_store_id, + file_id=file_id, + attributes=attributes, + chunking_strategy=chunking_strategy, + ) + + async def openai_list_files_in_vector_store( + self, + vector_store_id: str, + limit: int | None = 20, + order: str | None = "desc", + after: str | None = None, + before: str | None = None, + filter: VectorStoreFileStatus | None = None, + ) -> list[VectorStoreFileObject]: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_list_files_in_vector_store( + vector_store_id=vector_store_id, + limit=limit, + order=order, + after=after, + before=before, + filter=filter, + ) + + async def openai_retrieve_vector_store_file( + self, + vector_store_id: str, + file_id: str, + ) -> VectorStoreFileObject: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file( + vector_store_id=vector_store_id, + file_id=file_id, + ) + + async def openai_retrieve_vector_store_file_contents( + self, + vector_store_id: str, + file_id: str, + ) -> VectorStoreFileContentsResponse: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file_contents( + vector_store_id=vector_store_id, + file_id=file_id, + ) + + async def openai_update_vector_store_file( + self, + vector_store_id: str, + file_id: str, + attributes: dict[str, Any], + ) -> VectorStoreFileObject: + await self.assert_action_allowed("update", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_update_vector_store_file( + vector_store_id=vector_store_id, + file_id=file_id, + attributes=attributes, + ) + + async def openai_delete_vector_store_file( + self, + vector_store_id: str, + file_id: str, + ) -> VectorStoreFileDeleteResponse: + await self.assert_action_allowed("delete", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_delete_vector_store_file( + vector_store_id=vector_store_id, + file_id=file_id, + ) diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 3ba042bd9..30f795d33 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -11,17 +11,15 @@ from unittest.mock import AsyncMock from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datatypes import Api -from llama_stack.apis.models import Model, ModelType +from llama_stack.apis.models import Model from llama_stack.apis.shields.shields import Shield from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter -from llama_stack.apis.vector_dbs.vector_dbs import VectorDB from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable from llama_stack.distribution.routing_tables.models import ModelsRoutingTable from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable -from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable class Impl: @@ -54,17 +52,6 @@ class SafetyImpl(Impl): return shield -class VectorDBImpl(Impl): - def __init__(self): - super().__init__(Api.vector_io) - - async def register_vector_db(self, vector_db: VectorDB): - return vector_db - - async def unregister_vector_db(self, vector_db_id: str): - return vector_db_id - - class DatasetsImpl(Impl): def __init__(self): super().__init__(Api.datasetio) @@ -173,36 +160,6 @@ async def test_shields_routing_table(cached_disk_dist_registry): assert "test-shield-2" in shield_ids -async def test_vectordbs_routing_table(cached_disk_dist_registry): - table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) - await table.initialize() - - m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) - await m_table.initialize() - await m_table.register_model( - model_id="test-model", - provider_id="test_provider", - metadata={"embedding_dimension": 128}, - model_type=ModelType.embedding, - ) - - # Register multiple vector databases and verify listing - await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model") - await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model") - vector_dbs = await table.list_vector_dbs() - - assert len(vector_dbs.data) == 2 - vector_db_ids = {v.identifier for v in vector_dbs.data} - assert "test-vectordb" in vector_db_ids - assert "test-vectordb-2" in vector_db_ids - - await table.unregister_vector_db(vector_db_id="test-vectordb") - await table.unregister_vector_db(vector_db_id="test-vectordb-2") - - vector_dbs = await table.list_vector_dbs() - assert len(vector_dbs.data) == 0 - - async def test_datasets_routing_table(cached_disk_dist_registry): table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {}) await table.initialize() diff --git a/tests/unit/distribution/routing_tables/__init__.py b/tests/unit/distribution/routing_tables/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/unit/distribution/routing_tables/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/unit/distribution/routing_tables/test_vector_dbs.py b/tests/unit/distribution/routing_tables/test_vector_dbs.py new file mode 100644 index 000000000..28887e1cf --- /dev/null +++ b/tests/unit/distribution/routing_tables/test_vector_dbs.py @@ -0,0 +1,274 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Unit tests for the routing tables vector_dbs + +import time +from unittest.mock import AsyncMock + +import pytest + +from llama_stack.apis.datatypes import Api +from llama_stack.apis.models import ModelType +from llama_stack.apis.vector_dbs.vector_dbs import VectorDB +from llama_stack.apis.vector_io.vector_io import ( + VectorStoreContent, + VectorStoreDeleteResponse, + VectorStoreFileContentsResponse, + VectorStoreFileCounts, + VectorStoreFileDeleteResponse, + VectorStoreFileObject, + VectorStoreObject, + VectorStoreSearchResponsePage, +) +from llama_stack.distribution.access_control.datatypes import AccessRule, Scope +from llama_stack.distribution.datatypes import User +from llama_stack.distribution.request_headers import request_provider_data_context +from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable +from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceImpl, ModelsRoutingTable + + +class VectorDBImpl(Impl): + def __init__(self): + super().__init__(Api.vector_io) + + async def register_vector_db(self, vector_db: VectorDB): + return vector_db + + async def unregister_vector_db(self, vector_db_id: str): + return vector_db_id + + async def openai_retrieve_vector_store(self, vector_store_id): + return VectorStoreObject( + id=vector_store_id, + name="Test Store", + created_at=int(time.time()), + file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0), + ) + + async def openai_update_vector_store(self, vector_store_id, **kwargs): + return VectorStoreObject( + id=vector_store_id, + name="Updated Store", + created_at=int(time.time()), + file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0), + ) + + async def openai_delete_vector_store(self, vector_store_id): + return VectorStoreDeleteResponse(id=vector_store_id, object="vector_store.deleted", deleted=True) + + async def openai_search_vector_store(self, vector_store_id, query, **kwargs): + return VectorStoreSearchResponsePage( + object="vector_store.search_results.page", search_query="query", data=[], has_more=False, next_page=None + ) + + async def openai_attach_file_to_vector_store(self, vector_store_id, file_id, **kwargs): + return VectorStoreFileObject( + id=file_id, + status="completed", + chunking_strategy={"type": "auto"}, + created_at=int(time.time()), + vector_store_id=vector_store_id, + ) + + async def openai_list_files_in_vector_store(self, vector_store_id, **kwargs): + return [ + VectorStoreFileObject( + id="1", + status="completed", + chunking_strategy={"type": "auto"}, + created_at=int(time.time()), + vector_store_id=vector_store_id, + ) + ] + + async def openai_retrieve_vector_store_file(self, vector_store_id, file_id): + return VectorStoreFileObject( + id=file_id, + status="completed", + chunking_strategy={"type": "auto"}, + created_at=int(time.time()), + vector_store_id=vector_store_id, + ) + + async def openai_retrieve_vector_store_file_contents(self, vector_store_id, file_id): + return VectorStoreFileContentsResponse( + file_id=file_id, + filename="Sample File name", + attributes={"key": "value"}, + content=[VectorStoreContent(type="text", text="Sample content")], + ) + + async def openai_update_vector_store_file(self, vector_store_id, file_id, **kwargs): + return VectorStoreFileObject( + id=file_id, + status="completed", + chunking_strategy={"type": "auto"}, + created_at=int(time.time()), + vector_store_id=vector_store_id, + ) + + async def openai_delete_vector_store_file(self, vector_store_id, file_id): + return VectorStoreFileDeleteResponse(id=file_id, deleted=True) + + +async def test_vectordbs_routing_table(cached_disk_dist_registry): + table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + # Register multiple vector databases and verify listing + await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model") + await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model") + vector_dbs = await table.list_vector_dbs() + + assert len(vector_dbs.data) == 2 + vector_db_ids = {v.identifier for v in vector_dbs.data} + assert "test-vectordb" in vector_db_ids + assert "test-vectordb-2" in vector_db_ids + + await table.unregister_vector_db(vector_db_id="test-vectordb") + await table.unregister_vector_db(vector_db_id="test-vectordb-2") + + vector_dbs = await table.list_vector_dbs() + assert len(vector_dbs.data) == 0 + + +async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry): + impl = VectorDBImpl() + impl.openai_retrieve_vector_store = AsyncMock(return_value="OK") + table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=[]) + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[]) + authorized_table = "vs1" + authorized_team = "team1" + unauthorized_team = "team2" + + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + authorized_user = User(principal="alice", attributes={"roles": [authorized_team]}) + with request_provider_data_context({}, authorized_user): + _ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model") + + # Authorized reader + with request_provider_data_context({}, authorized_user): + res = await table.openai_retrieve_vector_store(authorized_table) + assert res == "OK" + + # Authorized updater + impl.openai_update_vector_store_file = AsyncMock(return_value="UPDATED") + with request_provider_data_context({}, authorized_user): + res = await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"}) + assert res == "UPDATED" + + # Unauthorized reader + unauthorized_user = User(principal="eve", attributes={"roles": [unauthorized_team]}) + with request_provider_data_context({}, unauthorized_user): + with pytest.raises(ValueError): + await table.openai_retrieve_vector_store(authorized_table) + + # Unauthorized updater + with request_provider_data_context({}, unauthorized_user): + with pytest.raises(ValueError): + await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"}) + + # Authorized deleter + impl.openai_delete_vector_store_file = AsyncMock(return_value="DELETED") + with request_provider_data_context({}, authorized_user): + res = await table.openai_delete_vector_store_file(authorized_table, file_id="file1") + assert res == "DELETED" + + # Unauthorized deleter + with request_provider_data_context({}, unauthorized_user): + with pytest.raises(ValueError): + await table.openai_delete_vector_store_file(authorized_table, file_id="file1") + + +async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_registry): + impl = VectorDBImpl() + + policy = [ + AccessRule(permit=Scope(actions=["create", "read", "update", "delete"]), when="user with admin in roles"), + AccessRule(permit=Scope(actions=["read"]), when="user with reader in roles"), + ] + + table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=policy) + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[]) + + vector_db_id = "vs1" + file_id = "file-1" + + admin_user = User(principal="admin", attributes={"roles": ["admin"]}) + read_only_user = User(principal="reader", attributes={"roles": ["reader"]}) + no_access_user = User(principal="outsider", attributes={"roles": ["no_access"]}) + + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + with request_provider_data_context({}, admin_user): + await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model") + + read_methods = [ + (table.openai_retrieve_vector_store, (vector_db_id,), {}), + (table.openai_search_vector_store, (vector_db_id, "query"), {}), + (table.openai_list_files_in_vector_store, (vector_db_id,), {}), + (table.openai_retrieve_vector_store_file, (vector_db_id, file_id), {}), + (table.openai_retrieve_vector_store_file_contents, (vector_db_id, file_id), {}), + ] + update_methods = [ + (table.openai_update_vector_store, (vector_db_id,), {"name": "Updated DB"}), + (table.openai_attach_file_to_vector_store, (vector_db_id, file_id), {}), + (table.openai_update_vector_store_file, (vector_db_id, file_id), {"attributes": {"key": "value"}}), + ] + delete_methods = [ + (table.openai_delete_vector_store_file, (vector_db_id, file_id), {}), + (table.openai_delete_vector_store, (vector_db_id,), {}), + ] + + for user in [admin_user, read_only_user]: + with request_provider_data_context({}, user): + for method, args, kwargs in read_methods: + result = await method(*args, **kwargs) + assert result is not None, f"Read operation failed with user {user.principal}" + + with request_provider_data_context({}, no_access_user): + for method, args, kwargs in read_methods: + with pytest.raises(ValueError): + await method(*args, **kwargs) + + with request_provider_data_context({}, admin_user): + for method, args, kwargs in update_methods: + result = await method(*args, **kwargs) + assert result is not None, "Update operation failed with admin user" + + with request_provider_data_context({}, admin_user): + for method, args, kwargs in delete_methods: + result = await method(*args, **kwargs) + assert result is not None, "Delete operation failed with admin user" + + for user in [read_only_user, no_access_user]: + with request_provider_data_context({}, user): + for method, args, kwargs in delete_methods: + with pytest.raises(ValueError): + await method(*args, **kwargs)