diff --git a/.github/workflows/coverage-badge.yml b/.github/workflows/coverage-badge.yml index 6b2f133dd..54bde1749 100644 --- a/.github/workflows/coverage-badge.yml +++ b/.github/workflows/coverage-badge.yml @@ -15,6 +15,9 @@ on: jobs: unit-tests: + permissions: + contents: write # for peter-evans/create-pull-request to create branch + pull-requests: write # for peter-evans/create-pull-request to create a PR runs-on: ubuntu-latest steps: - name: Checkout repository diff --git a/.github/workflows/python-build-test.yml b/.github/workflows/python-build-test.yml index 63ddd9b54..efd1f2cc9 100644 --- a/.github/workflows/python-build-test.yml +++ b/.github/workflows/python-build-test.yml @@ -20,7 +20,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Install uv - uses: astral-sh/setup-uv@bd01e18f51369d5a26f1651c3cb451d3417e3bba # v6.3.1 + uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f # v6.4.1 with: python-version: ${{ matrix.python-version }} activate-environment: true diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 75b29213c..8d866328b 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -10,8 +10,13 @@ If in doubt, please open a [discussion](https://github.com/meta-llama/llama-stac **I'd like to contribute!** -All issues are actionable (please report if they are not.) Pick one and start working on it. Thank you. -If you need help or guidance, comment on the issue. Issues that are extra friendly to new contributors are tagged with "contributor friendly". +If you are new to the project, start by looking at the issues tagged with "good first issue". If you're interested +leave a comment on the issue and a triager will assign it to you. + +Please avoid picking up too many issues at once. This helps you stay focused and ensures that others in the community also have opportunities to contribute. +- Try to work on only 1–2 issues at a time, especially if you’re still getting familiar with the codebase. +- Before taking an issue, check if it’s already assigned or being actively discussed. +- If you’re blocked or can’t continue with an issue, feel free to unassign yourself or leave a comment so others can step in. **I have a bug!** @@ -41,6 +46,15 @@ If you need help or guidance, comment on the issue. Issues that are extra friend 4. Make sure your code lints using `pre-commit`. 5. If you haven't already, complete the Contributor License Agreement ("CLA"). 6. Ensure your pull request follows the [conventional commits format](https://www.conventionalcommits.org/en/v1.0.0/). +7. Ensure your pull request follows the [coding style](#coding-style). + + +Please keep pull requests (PRs) small and focused. If you have a large set of changes, consider splitting them into logically grouped, smaller PRs to facilitate review and testing. + +> [!TIP] +> As a general guideline: +> - Experienced contributors should try to keep no more than 5 open PRs at a time. +> - New contributors are encouraged to have only one open PR at a time until they’re familiar with the codebase and process. ## Contributor License Agreement ("CLA") In order to accept your pull request, we need you to submit a CLA. You only need @@ -140,7 +154,9 @@ uv sync * Don't use unicode characters in the codebase. ASCII-only is preferred for compatibility or readability reasons. * Providers configuration class should be Pydantic Field class. It should have a `description` field - that describes the configuration. These descriptions will be used to generate the provider documentation. 
+ that describes the configuration. These descriptions will be used to generate the provider + documentation. +* When possible, use keyword arguments only when calling functions. ## Common Tasks diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index d7801ba1c..38e53a438 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -13596,9 +13596,6 @@ } }, "additionalProperties": false, - "required": [ - "name" - ], "title": "OpenaiCreateVectorStoreRequest" }, "VectorStoreFileCounts": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index be02e1e42..0df60ddf4 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -9497,8 +9497,6 @@ components: description: >- The ID of the provider to use for this vector store. additionalProperties: false - required: - - name title: OpenaiCreateVectorStoreRequest VectorStoreFileCounts: type: object diff --git a/docs/source/getting_started/demo_script.py b/docs/source/getting_started/demo_script.py new file mode 100644 index 000000000..298fd9899 --- /dev/null +++ b/docs/source/getting_started/demo_script.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient + +vector_db_id = "my_demo_vector_db" +client = LlamaStackClient(base_url="http://localhost:8321") + +models = client.models.list() + +# Select the first LLM and first embedding models +model_id = next(m for m in models if m.model_type == "llm").identifier +embedding_model_id = ( + em := next(m for m in models if m.model_type == "embedding") +).identifier +embedding_dimension = em.metadata["embedding_dimension"] + +_ = client.vector_dbs.register( + vector_db_id=vector_db_id, + embedding_model=embedding_model_id, + embedding_dimension=embedding_dimension, + provider_id="faiss", +) +source = "https://www.paulgraham.com/greatwork.html" +print("rag_tool> Ingesting document:", source) +document = RAGDocument( + document_id="document_1", + content=source, + mime_type="text/html", + metadata={}, +) +client.tool_runtime.rag_tool.insert( + documents=[document], + vector_db_id=vector_db_id, + chunk_size_in_tokens=50, +) +agent = Agent( + client, + model=model_id, + instructions="You are a helpful assistant", + tools=[ + { + "name": "builtin::rag/knowledge_search", + "args": {"vector_db_ids": [vector_db_id]}, + } + ], +) + +prompt = "How do you do great work?" +print("prompt>", prompt) + +response = agent.create_turn( + messages=[{"role": "user", "content": prompt}], + session_id=agent.create_session("rag_session"), + stream=True, +) + +for log in AgentEventLogger().log(response): + log.print() diff --git a/docs/source/getting_started/quickstart.md b/docs/source/getting_started/quickstart.md index 59791643d..5549f412c 100644 --- a/docs/source/getting_started/quickstart.md +++ b/docs/source/getting_started/quickstart.md @@ -24,63 +24,8 @@ ENABLE_OLLAMA=ollama OLLAMA_INFERENCE_MODEL=llama3.2:3b uv run --with llama-stac #### Step 3: Run the demo Now open up a new terminal and copy the following script into a file named `demo_script.py`. 
-```python -from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient - -vector_db_id = "my_demo_vector_db" -client = LlamaStackClient(base_url="http://localhost:8321") - -models = client.models.list() - -# Select the first LLM and first embedding models -model_id = next(m for m in models if m.model_type == "llm").identifier -embedding_model_id = ( - em := next(m for m in models if m.model_type == "embedding") -).identifier -embedding_dimension = em.metadata["embedding_dimension"] - -_ = client.vector_dbs.register( - vector_db_id=vector_db_id, - embedding_model=embedding_model_id, - embedding_dimension=embedding_dimension, - provider_id="faiss", -) -source = "https://www.paulgraham.com/greatwork.html" -print("rag_tool> Ingesting document:", source) -document = RAGDocument( - document_id="document_1", - content=source, - mime_type="text/html", - metadata={}, -) -client.tool_runtime.rag_tool.insert( - documents=[document], - vector_db_id=vector_db_id, - chunk_size_in_tokens=50, -) -agent = Agent( - client, - model=model_id, - instructions="You are a helpful assistant", - tools=[ - { - "name": "builtin::rag/knowledge_search", - "args": {"vector_db_ids": [vector_db_id]}, - } - ], -) - -prompt = "How do you do great work?" -print("prompt>", prompt) - -response = agent.create_turn( - messages=[{"role": "user", "content": prompt}], - session_id=agent.create_session("rag_session"), - stream=True, -) - -for log in AgentEventLogger().log(response): - log.print() +```{literalinclude} ./demo_script.py +:language: python ``` We will use `uv` to run the script ``` diff --git a/docs/source/providers/inference/index.md b/docs/source/providers/inference/index.md index 6582e08de..dcc6da5b5 100644 --- a/docs/source/providers/inference/index.md +++ b/docs/source/providers/inference/index.md @@ -7,13 +7,10 @@ This section contains documentation for all available providers for the **infere - [remote::anthropic](remote_anthropic.md) - [remote::bedrock](remote_bedrock.md) - [remote::cerebras](remote_cerebras.md) -- [remote::cerebras-openai-compat](remote_cerebras-openai-compat.md) - [remote::databricks](remote_databricks.md) - [remote::fireworks](remote_fireworks.md) -- [remote::fireworks-openai-compat](remote_fireworks-openai-compat.md) - [remote::gemini](remote_gemini.md) - [remote::groq](remote_groq.md) -- [remote::groq-openai-compat](remote_groq-openai-compat.md) - [remote::hf::endpoint](remote_hf_endpoint.md) - [remote::hf::serverless](remote_hf_serverless.md) - [remote::llama-openai-compat](remote_llama-openai-compat.md) @@ -23,9 +20,7 @@ This section contains documentation for all available providers for the **infere - [remote::passthrough](remote_passthrough.md) - [remote::runpod](remote_runpod.md) - [remote::sambanova](remote_sambanova.md) -- [remote::sambanova-openai-compat](remote_sambanova-openai-compat.md) - [remote::tgi](remote_tgi.md) - [remote::together](remote_together.md) -- [remote::together-openai-compat](remote_together-openai-compat.md) - [remote::vllm](remote_vllm.md) - [remote::watsonx](remote_watsonx.md) \ No newline at end of file diff --git a/docs/source/providers/vector_io/remote_pgvector.md b/docs/source/providers/vector_io/remote_pgvector.md index 3e7d6e776..74f588a13 100644 --- a/docs/source/providers/vector_io/remote_pgvector.md +++ b/docs/source/providers/vector_io/remote_pgvector.md @@ -17,7 +17,7 @@ That means you'll get fast and efficient vector retrieval. To use PGVector in your Llama Stack project, follow these steps: 1. 
Install the necessary dependencies. -2. Configure your Llama Stack project to use Faiss. +2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector). 3. Start storing and querying vectors. ## Installation diff --git a/llama_stack/apis/vector_dbs/vector_dbs.py b/llama_stack/apis/vector_dbs/vector_dbs.py index 0d160737a..325e21bab 100644 --- a/llama_stack/apis/vector_dbs/vector_dbs.py +++ b/llama_stack/apis/vector_dbs/vector_dbs.py @@ -34,6 +34,7 @@ class VectorDBInput(BaseModel): vector_db_id: str embedding_model: str embedding_dimension: int + provider_id: str | None = None provider_vector_db_id: str | None = None diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py index 618ac2a95..853c4656c 100644 --- a/llama_stack/apis/vector_io/vector_io.py +++ b/llama_stack/apis/vector_io/vector_io.py @@ -338,7 +338,7 @@ class VectorIO(Protocol): @webmethod(route="/openai/v1/vector_stores", method="POST") async def openai_create_vector_store( self, - name: str, + name: str | None = None, file_ids: list[str] | None = None, expires_after: dict[str, Any] | None = None, chunking_strategy: dict[str, Any] | None = None, diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index b573b2edc..3f94b1e2c 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -276,8 +276,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None: config = parse_and_maybe_upgrade_config(config_dict) if config.external_providers_dir and not config.external_providers_dir.exists(): config.external_providers_dir.mkdir(exist_ok=True) - run_args = formulate_run_args(args.image_type, args.image_name, config, args.template) - run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", run_config]) + run_args = formulate_run_args(args.image_type, args.image_name) + run_args.extend([str(os.getenv("LLAMA_STACK_PORT", 8321)), "--config", str(run_config)]) run_command(run_args) diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index f4a119522..3cb2e213c 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -82,39 +82,6 @@ class StackRun(Subcommand): return ImageType.CONDA.value, args.image_name return args.image_type, args.image_name - def _resolve_config_and_template(self, args: argparse.Namespace) -> tuple[Path | None, str | None]: - """Resolve config file path and template name from args.config""" - from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR - - if not args.config: - return None, None - - config_file = Path(args.config) - has_yaml_suffix = args.config.endswith(".yaml") - template_name = None - - if not config_file.exists() and not has_yaml_suffix: - # check if this is a template - config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml" - if config_file.exists(): - template_name = args.config - - if not config_file.exists() and not has_yaml_suffix: - # check if it's a build config saved to ~/.llama dir - config_file = Path(DISTRIBS_BASE_DIR / f"llamastack-{args.config}" / f"{args.config}-run.yaml") - - if not config_file.exists(): - self.parser.error( - f"File {str(config_file)} does not exist.\n\nPlease run `llama stack build` to generate (and optionally edit) a run.yaml file" - ) - - if not config_file.is_file(): - self.parser.error( - f"Config file must be a valid file path, '{config_file}' is not a file: type={type(config_file)}" - ) - - return config_file, template_name - def 
_run_stack_run_cmd(self, args: argparse.Namespace) -> None: import yaml @@ -125,8 +92,15 @@ class StackRun(Subcommand): self._start_ui_development_server(args.port) image_type, image_name = self._get_image_type_and_name(args) - # Resolve config file and template name first - config_file, template_name = self._resolve_config_and_template(args) + if args.config: + try: + from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template + + config_file = resolve_config_or_template(args.config, Mode.RUN) + except ValueError as e: + self.parser.error(str(e)) + else: + config_file = None # Check if config is required based on image type if (image_type in [ImageType.CONDA.value, ImageType.VENV.value]) and not config_file: @@ -164,18 +138,14 @@ class StackRun(Subcommand): if callable(getattr(args, arg)): continue if arg == "config": - if template_name: - server_args.template = str(template_name) - else: - # Set the config file path - server_args.config = str(config_file) + server_args.config = str(config_file) else: setattr(server_args, arg, getattr(args, arg)) # Run the server server_main(server_args) else: - run_args = formulate_run_args(image_type, image_name, config, template_name) + run_args = formulate_run_args(image_type, image_name) run_args.extend([str(args.port)]) diff --git a/llama_stack/cli/utils.py b/llama_stack/cli/utils.py new file mode 100644 index 000000000..433627cc0 --- /dev/null +++ b/llama_stack/cli/utils.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse + + +def add_config_template_args(parser: argparse.ArgumentParser): + """Add unified config/template arguments with backward compatibility.""" + group = parser.add_mutually_exclusive_group(required=True) + + group.add_argument( + "config", + nargs="?", + help="Configuration file path or template name", + ) + + # Backward compatibility arguments (deprecated) + group.add_argument( + "--config", + dest="config", + help="(DEPRECATED) Use positional argument [config] instead. Configuration file path", + ) + + group.add_argument( + "--template", + dest="config", + help="(DEPRECATED) Use positional argument [config] instead. 
Template name", + ) diff --git a/llama_stack/distribution/routers/vector_io.py b/llama_stack/distribution/routers/vector_io.py index cd56ada7b..a1dd66060 100644 --- a/llama_stack/distribution/routers/vector_io.py +++ b/llama_stack/distribution/routers/vector_io.py @@ -214,9 +214,7 @@ class VectorIORouter(VectorIO): vector_store_id: str, ) -> VectorStoreObject: logger.debug(f"VectorIORouter.openai_retrieve_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_retrieve_vector_store(vector_store_id) + return await self.routing_table.openai_retrieve_vector_store(vector_store_id) async def openai_update_vector_store( self, @@ -226,9 +224,7 @@ class VectorIORouter(VectorIO): metadata: dict[str, Any] | None = None, ) -> VectorStoreObject: logger.debug(f"VectorIORouter.openai_update_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_update_vector_store( + return await self.routing_table.openai_update_vector_store( vector_store_id=vector_store_id, name=name, expires_after=expires_after, @@ -240,12 +236,7 @@ class VectorIORouter(VectorIO): vector_store_id: str, ) -> VectorStoreDeleteResponse: logger.debug(f"VectorIORouter.openai_delete_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - result = await provider.openai_delete_vector_store(vector_store_id) - # drop from registry - await self.routing_table.unregister_vector_db(vector_store_id) - return result + return await self.routing_table.openai_delete_vector_store(vector_store_id) async def openai_search_vector_store( self, @@ -258,9 +249,7 @@ class VectorIORouter(VectorIO): search_mode: str | None = "vector", ) -> VectorStoreSearchResponsePage: logger.debug(f"VectorIORouter.openai_search_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_search_vector_store( + return await self.routing_table.openai_search_vector_store( vector_store_id=vector_store_id, query=query, filters=filters, @@ -278,9 +267,7 @@ class VectorIORouter(VectorIO): chunking_strategy: VectorStoreChunkingStrategy | None = None, ) -> VectorStoreFileObject: logger.debug(f"VectorIORouter.openai_attach_file_to_vector_store: {vector_store_id}, {file_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_attach_file_to_vector_store( + return await self.routing_table.openai_attach_file_to_vector_store( vector_store_id=vector_store_id, file_id=file_id, attributes=attributes, @@ -297,9 +284,7 @@ class VectorIORouter(VectorIO): filter: VectorStoreFileStatus | None = None, ) -> list[VectorStoreFileObject]: logger.debug(f"VectorIORouter.openai_list_files_in_vector_store: {vector_store_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_list_files_in_vector_store( + return await self.routing_table.openai_list_files_in_vector_store( vector_store_id=vector_store_id, limit=limit, order=order, @@ -314,9 +299,7 @@ class VectorIORouter(VectorIO): file_id: str, ) -> VectorStoreFileObject: logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file: {vector_store_id}, {file_id}") - # Route based on vector 
store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_retrieve_vector_store_file( + return await self.routing_table.openai_retrieve_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, ) @@ -327,9 +310,7 @@ class VectorIORouter(VectorIO): file_id: str, ) -> VectorStoreFileContentsResponse: logger.debug(f"VectorIORouter.openai_retrieve_vector_store_file_contents: {vector_store_id}, {file_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_retrieve_vector_store_file_contents( + return await self.routing_table.openai_retrieve_vector_store_file_contents( vector_store_id=vector_store_id, file_id=file_id, ) @@ -341,9 +322,7 @@ class VectorIORouter(VectorIO): attributes: dict[str, Any], ) -> VectorStoreFileObject: logger.debug(f"VectorIORouter.openai_update_vector_store_file: {vector_store_id}, {file_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_update_vector_store_file( + return await self.routing_table.openai_update_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, attributes=attributes, @@ -355,9 +334,7 @@ class VectorIORouter(VectorIO): file_id: str, ) -> VectorStoreFileDeleteResponse: logger.debug(f"VectorIORouter.openai_delete_vector_store_file: {vector_store_id}, {file_id}") - # Route based on vector store ID - provider = self.routing_table.get_provider_impl(vector_store_id) - return await provider.openai_delete_vector_store_file( + return await self.routing_table.openai_delete_vector_store_file( vector_store_id=vector_store_id, file_id=file_id, ) diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py index 7f7de32fe..bbe0113e9 100644 --- a/llama_stack/distribution/routing_tables/common.py +++ b/llama_stack/distribution/routing_tables/common.py @@ -9,6 +9,7 @@ from typing import Any from llama_stack.apis.resource import ResourceType from llama_stack.apis.scoring_functions import ScoringFn from llama_stack.distribution.access_control.access_control import AccessDeniedError, is_action_allowed +from llama_stack.distribution.access_control.datatypes import Action from llama_stack.distribution.datatypes import ( AccessRule, RoutableObject, @@ -209,6 +210,20 @@ class CommonRoutingTableImpl(RoutingTable): await self.dist_registry.register(obj) return obj + async def assert_action_allowed( + self, + action: Action, + type: str, + identifier: str, + ) -> None: + """Fetch a registered object by type/identifier and enforce the given action permission.""" + obj = await self.get_object_by_identifier(type, identifier) + if obj is None: + raise ValueError(f"{type.capitalize()} '{identifier}' not found") + user = get_authenticated_user() + if not is_action_allowed(self.policy, action, obj, user): + raise AccessDeniedError(action, obj, user) + async def get_all_with_type(self, type: str) -> list[RoutableObjectWithProvider]: objs = await self.dist_registry.get_all() filtered_objs = [obj for obj in objs if obj.type == type] diff --git a/llama_stack/distribution/routing_tables/vector_dbs.py b/llama_stack/distribution/routing_tables/vector_dbs.py index f861102c8..b4e60c625 100644 --- a/llama_stack/distribution/routing_tables/vector_dbs.py +++ b/llama_stack/distribution/routing_tables/vector_dbs.py @@ -4,11 +4,24 @@ # This source code is licensed under the terms described in the LICENSE 
file in # the root directory of this source tree. +from typing import Any + from pydantic import TypeAdapter from llama_stack.apis.models import ModelType from llama_stack.apis.resource import ResourceType from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs +from llama_stack.apis.vector_io.vector_io import ( + SearchRankingOptions, + VectorStoreChunkingStrategy, + VectorStoreDeleteResponse, + VectorStoreFileContentsResponse, + VectorStoreFileDeleteResponse, + VectorStoreFileObject, + VectorStoreFileStatus, + VectorStoreObject, + VectorStoreSearchResponsePage, +) from llama_stack.distribution.datatypes import ( VectorDBWithOwner, ) @@ -74,3 +87,135 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): if existing_vector_db is None: raise ValueError(f"Vector DB {vector_db_id} not found") await self.unregister_object(existing_vector_db) + + async def openai_retrieve_vector_store( + self, + vector_store_id: str, + ) -> VectorStoreObject: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store(vector_store_id) + + async def openai_update_vector_store( + self, + vector_store_id: str, + name: str | None = None, + expires_after: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, + ) -> VectorStoreObject: + await self.assert_action_allowed("update", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_update_vector_store( + vector_store_id=vector_store_id, + name=name, + expires_after=expires_after, + metadata=metadata, + ) + + async def openai_delete_vector_store( + self, + vector_store_id: str, + ) -> VectorStoreDeleteResponse: + await self.assert_action_allowed("delete", "vector_db", vector_store_id) + result = await self.get_provider_impl(vector_store_id).openai_delete_vector_store(vector_store_id) + await self.unregister_vector_db(vector_store_id) + return result + + async def openai_search_vector_store( + self, + vector_store_id: str, + query: str | list[str], + filters: dict[str, Any] | None = None, + max_num_results: int | None = 10, + ranking_options: SearchRankingOptions | None = None, + rewrite_query: bool | None = False, + search_mode: str | None = "vector", + ) -> VectorStoreSearchResponsePage: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_search_vector_store( + vector_store_id=vector_store_id, + query=query, + filters=filters, + max_num_results=max_num_results, + ranking_options=ranking_options, + rewrite_query=rewrite_query, + search_mode=search_mode, + ) + + async def openai_attach_file_to_vector_store( + self, + vector_store_id: str, + file_id: str, + attributes: dict[str, Any] | None = None, + chunking_strategy: VectorStoreChunkingStrategy | None = None, + ) -> VectorStoreFileObject: + await self.assert_action_allowed("update", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_attach_file_to_vector_store( + vector_store_id=vector_store_id, + file_id=file_id, + attributes=attributes, + chunking_strategy=chunking_strategy, + ) + + async def openai_list_files_in_vector_store( + self, + vector_store_id: str, + limit: int | None = 20, + order: str | None = "desc", + after: str | None = None, + before: str | None = None, + filter: VectorStoreFileStatus | None = None, + ) -> list[VectorStoreFileObject]: + await self.assert_action_allowed("read", "vector_db", 
vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_list_files_in_vector_store( + vector_store_id=vector_store_id, + limit=limit, + order=order, + after=after, + before=before, + filter=filter, + ) + + async def openai_retrieve_vector_store_file( + self, + vector_store_id: str, + file_id: str, + ) -> VectorStoreFileObject: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file( + vector_store_id=vector_store_id, + file_id=file_id, + ) + + async def openai_retrieve_vector_store_file_contents( + self, + vector_store_id: str, + file_id: str, + ) -> VectorStoreFileContentsResponse: + await self.assert_action_allowed("read", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_retrieve_vector_store_file_contents( + vector_store_id=vector_store_id, + file_id=file_id, + ) + + async def openai_update_vector_store_file( + self, + vector_store_id: str, + file_id: str, + attributes: dict[str, Any], + ) -> VectorStoreFileObject: + await self.assert_action_allowed("update", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_update_vector_store_file( + vector_store_id=vector_store_id, + file_id=file_id, + attributes=attributes, + ) + + async def openai_delete_vector_store_file( + self, + vector_store_id: str, + file_id: str, + ) -> VectorStoreFileDeleteResponse: + await self.assert_action_allowed("delete", "vector_db", vector_store_id) + return await self.get_provider_impl(vector_store_id).openai_delete_vector_store_file( + vector_store_id=vector_store_id, + file_id=file_id, + ) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index e7e9e5e88..e58c28f2e 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -32,6 +32,7 @@ from openai import BadRequestError from pydantic import BaseModel, ValidationError from llama_stack.apis.common.responses import PaginatedResponse +from llama_stack.cli.utils import add_config_template_args from llama_stack.distribution.access_control.access_control import AccessDeniedError from llama_stack.distribution.datatypes import ( AuthenticationRequiredError, @@ -53,6 +54,7 @@ from llama_stack.distribution.stack import ( validate_env_pair, ) from llama_stack.distribution.utils.config import redact_sensitive_fields +from llama_stack.distribution.utils.config_resolution import Mode, resolve_config_or_template from llama_stack.distribution.utils.context import preserve_contexts_async_generator from llama_stack.log import get_logger from llama_stack.providers.datatypes import Api @@ -377,20 +379,8 @@ class ClientVersionMiddleware: def main(args: argparse.Namespace | None = None): """Start the LlamaStack server.""" parser = argparse.ArgumentParser(description="Start the LlamaStack server.") - parser.add_argument( - "--yaml-config", - dest="config", - help="(Deprecated) Path to YAML configuration file - use --config instead", - ) - parser.add_argument( - "--config", - dest="config", - help="Path to YAML configuration file", - ) - parser.add_argument( - "--template", - help="One of the template names in llama_stack/templates (e.g., tgi, fireworks, remote-vllm, etc.)", - ) + + add_config_template_args(parser) parser.add_argument( "--port", type=int, @@ -409,20 +399,7 @@ def main(args: argparse.Namespace | None = None): if args is None: args = parser.parse_args() - log_line = 
"" - if hasattr(args, "config") and args.config: - # if the user provided a config file, use it, even if template was specified - config_file = Path(args.config) - if not config_file.exists(): - raise ValueError(f"Config file {config_file} does not exist") - log_line = f"Using config file: {config_file}" - elif hasattr(args, "template") and args.template: - config_file = Path(REPO_ROOT) / "llama_stack" / "templates" / args.template / "run.yaml" - if not config_file.exists(): - raise ValueError(f"Template {args.template} does not exist") - log_line = f"Using template {args.template} config file: {config_file}" - else: - raise ValueError("Either --config or --template must be provided") + config_file = resolve_config_or_template(args.config, Mode.RUN) logger_config = None with open(config_file) as fp: @@ -442,9 +419,6 @@ def main(args: argparse.Namespace | None = None): config = replace_env_vars(config_contents) config = StackRunConfig(**cast_image_name_to_string(config)) - # now that the logger is initialized, print the line about which type of config we are using. - logger.info(log_line) - _log_run_config(run_config=config) app = FastAPI( @@ -592,12 +566,29 @@ def main(args: argparse.Namespace | None = None): "port": port, "lifespan": "on", "log_level": logger.getEffectiveLevel(), + "log_config": logger_config, } if ssl_config: uvicorn_config.update(ssl_config) # Run uvicorn in the existing event loop to preserve background tasks - loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve()) + # We need to catch KeyboardInterrupt because uvicorn's signal handling + # re-raises SIGINT signals using signal.raise_signal(), which Python + # converts to KeyboardInterrupt. Without this catch, we'd get a confusing + # stack trace when using Ctrl+C or kill -2 (SIGINT). + # SIGTERM (kill -15) works fine without this because Python doesn't + # have a default handler for it. + # + # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own + # signal handling but this is quite intrusive and not worth the effort. + try: + loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve()) + except (KeyboardInterrupt, SystemExit): + logger.info("Received interrupt signal, shutting down gracefully...") + finally: + if not loop.is_closed(): + logger.debug("Closing event loop") + loop.close() def _log_run_config(run_config: StackRunConfig): diff --git a/llama_stack/distribution/start_stack.sh b/llama_stack/distribution/start_stack.sh index 85bfceec4..77a7dc92e 100755 --- a/llama_stack/distribution/start_stack.sh +++ b/llama_stack/distribution/start_stack.sh @@ -117,7 +117,7 @@ if [[ "$env_type" == "venv" || "$env_type" == "conda" ]]; then set -x if [ -n "$yaml_config" ]; then - yaml_config_arg="--config $yaml_config" + yaml_config_arg="$yaml_config" else yaml_config_arg="" fi diff --git a/llama_stack/distribution/utils/config_resolution.py b/llama_stack/distribution/utils/config_resolution.py new file mode 100644 index 000000000..7e8de1242 --- /dev/null +++ b/llama_stack/distribution/utils/config_resolution.py @@ -0,0 +1,125 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from enum import StrEnum +from pathlib import Path + +from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="config_resolution") + + +TEMPLATE_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "templates" + + +class Mode(StrEnum): + RUN = "run" + BUILD = "build" + + +def resolve_config_or_template( + config_or_template: str, + mode: Mode = Mode.RUN, +) -> Path: + """ + Resolve a config/template argument to a concrete config file path. + + Args: + config_or_template: User input (file path, template name, or built distribution) + mode: Mode resolving for ("run", "build", "server") + + Returns: + Path to the resolved config file + + Raises: + ValueError: If resolution fails + """ + + # Strategy 1: Try as file path first + config_path = Path(config_or_template) + if config_path.exists() and config_path.is_file(): + logger.info(f"Using file path: {config_path}") + return config_path.resolve() + + # Strategy 2: Try as template name (if no .yaml extension) + if not config_or_template.endswith(".yaml"): + template_config = _get_template_config_path(config_or_template, mode) + if template_config.exists(): + logger.info(f"Using template: {template_config}") + return template_config + + # Strategy 3: Try as built distribution name + distrib_config = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml" + if distrib_config.exists(): + logger.info(f"Using built distribution: {distrib_config}") + return distrib_config + + distrib_config = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml" + if distrib_config.exists(): + logger.info(f"Using built distribution: {distrib_config}") + return distrib_config + + # Strategy 4: Failed - provide helpful error + raise ValueError(_format_resolution_error(config_or_template, mode)) + + +def _get_template_config_path(template_name: str, mode: Mode) -> Path: + """Get the config file path for a template.""" + return TEMPLATE_DIR / template_name / f"{mode}.yaml" + + +def _format_resolution_error(config_or_template: str, mode: Mode) -> str: + """Format a helpful error message for resolution failures.""" + from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR + + template_path = _get_template_config_path(config_or_template, mode) + distrib_path = DISTRIBS_BASE_DIR / f"llamastack-{config_or_template}" / f"{config_or_template}-{mode}.yaml" + distrib_path2 = DISTRIBS_BASE_DIR / f"{config_or_template}" / f"{config_or_template}-{mode}.yaml" + + available_templates = _get_available_templates() + templates_str = ", ".join(available_templates) if available_templates else "none found" + + return f"""Could not resolve config or template '{config_or_template}'. + +Tried the following locations: + 1. As file path: {Path(config_or_template).resolve()} + 2. As template: {template_path} + 3. As built distribution: ({distrib_path}, {distrib_path2}) + +Available templates: {templates_str} + +Did you mean one of these templates? 
+{_format_template_suggestions(available_templates, config_or_template)} +""" + + +def _get_available_templates() -> list[str]: + """Get list of available template names.""" + if not TEMPLATE_DIR.exists() and not DISTRIBS_BASE_DIR.exists(): + return [] + + return list( + set( + [d.name for d in TEMPLATE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")] + + [d.name for d in DISTRIBS_BASE_DIR.iterdir() if d.is_dir() and not d.name.startswith(".")] + ) + ) + + +def _format_template_suggestions(templates: list[str], user_input: str) -> str: + """Format template suggestions for error messages, showing closest matches first.""" + if not templates: + return " (no templates found)" + + import difflib + + # Get up to 3 closest matches with similarity threshold of 0.3 (lower = more permissive) + close_matches = difflib.get_close_matches(user_input, templates, n=3, cutoff=0.3) + display_templates = close_matches if close_matches else templates[:3] + + suggestions = [f" - {t}" for t in display_templates] + return "\n".join(suggestions) diff --git a/llama_stack/distribution/utils/exec.py b/llama_stack/distribution/utils/exec.py index 2db01689f..c646ae821 100644 --- a/llama_stack/distribution/utils/exec.py +++ b/llama_stack/distribution/utils/exec.py @@ -21,7 +21,7 @@ from pathlib import Path from llama_stack.distribution.utils.image_types import LlamaStackImageType -def formulate_run_args(image_type, image_name, config, template_name) -> list: +def formulate_run_args(image_type: str, image_name: str) -> list[str]: env_name = "" if image_type == LlamaStackImageType.CONDA.value: diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 4d2b9f8bf..3c34c71fb 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -10,6 +10,7 @@ import re import secrets import string import uuid +import warnings from collections.abc import AsyncGenerator from datetime import UTC, datetime @@ -911,8 +912,16 @@ async def load_data_from_url(url: str) -> str: async def get_raw_document_text(document: Document) -> str: - if not document.mime_type.startswith("text/"): + # Handle deprecated text/yaml mime type with warning + if document.mime_type == "text/yaml": + warnings.warn( + "The 'text/yaml' MIME type is deprecated. Please use 'application/yaml' instead.", + DeprecationWarning, + stacklevel=2, + ) + elif not (document.mime_type.startswith("text/") or document.mime_type == "application/yaml"): raise ValueError(f"Unexpected document mime type: {document.mime_type}") + if isinstance(document.content, URL): return await load_data_from_url(document.content.uri) elif isinstance(document.content, str): diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index cda535937..437d617ad 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -128,6 +128,11 @@ class AgentPersistence: except Exception as e: log.error(f"Error parsing turn: {e}") continue + + # The kvstore does not guarantee order, so we sort by started_at + # to ensure consistent ordering of turns. 
+ turns.sort(key=lambda t: t.started_at) + return turns async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None: diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index ffd30a5b5..a8bc96a77 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -224,17 +224,6 @@ def available_providers() -> list[ProviderSpec]: description="Groq inference provider for ultra-fast inference using Groq's LPU technology.", ), ), - remote_provider_spec( - api=Api.inference, - adapter=AdapterSpec( - adapter_type="fireworks-openai-compat", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.fireworks_openai_compat", - config_class="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksCompatConfig", - provider_data_validator="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksProviderDataValidator", - description="Fireworks AI OpenAI-compatible provider for using Fireworks models with OpenAI API format.", - ), - ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( @@ -246,50 +235,6 @@ def available_providers() -> list[ProviderSpec]: description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.", ), ), - remote_provider_spec( - api=Api.inference, - adapter=AdapterSpec( - adapter_type="together-openai-compat", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.together_openai_compat", - config_class="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherCompatConfig", - provider_data_validator="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherProviderDataValidator", - description="Together AI OpenAI-compatible provider for using Together models with OpenAI API format.", - ), - ), - remote_provider_spec( - api=Api.inference, - adapter=AdapterSpec( - adapter_type="groq-openai-compat", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.groq_openai_compat", - config_class="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqCompatConfig", - provider_data_validator="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqProviderDataValidator", - description="Groq OpenAI-compatible provider for using Groq models with OpenAI API format.", - ), - ), - remote_provider_spec( - api=Api.inference, - adapter=AdapterSpec( - adapter_type="sambanova-openai-compat", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.sambanova_openai_compat", - config_class="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaCompatConfig", - provider_data_validator="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaProviderDataValidator", - description="SambaNova OpenAI-compatible provider for using SambaNova models with OpenAI API format.", - ), - ), - remote_provider_spec( - api=Api.inference, - adapter=AdapterSpec( - adapter_type="cerebras-openai-compat", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.cerebras_openai_compat", - config_class="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasCompatConfig", - provider_data_validator="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasProviderDataValidator", - description="Cerebras OpenAI-compatible provider for using Cerebras models with OpenAI API format.", 
- ), - ), remote_provider_spec( api=Api.inference, adapter=AdapterSpec( diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index c13e65bbc..e391341b4 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -395,7 +395,7 @@ That means you'll get fast and efficient vector retrieval. To use PGVector in your Llama Stack project, follow these steps: 1. Install the necessary dependencies. -2. Configure your Llama Stack project to use Faiss. +2. Configure your Llama Stack project to use pgvector. (e.g. remote::pgvector). 3. Start storing and querying vectors. ## Installation diff --git a/llama_stack/providers/remote/inference/cerebras_openai_compat/__init__.py b/llama_stack/providers/remote/inference/cerebras_openai_compat/__init__.py deleted file mode 100644 index 523a8dfe7..000000000 --- a/llama_stack/providers/remote/inference/cerebras_openai_compat/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.inference import InferenceProvider - -from .config import CerebrasCompatConfig - - -async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> InferenceProvider: - # import dynamically so the import is used only when it is needed - from .cerebras import CerebrasCompatInferenceAdapter - - adapter = CerebrasCompatInferenceAdapter(config) - return adapter diff --git a/llama_stack/providers/remote/inference/cerebras_openai_compat/cerebras.py b/llama_stack/providers/remote/inference/cerebras_openai_compat/cerebras.py deleted file mode 100644 index b3f109dcc..000000000 --- a/llama_stack/providers/remote/inference/cerebras_openai_compat/cerebras.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.providers.remote.inference.cerebras_openai_compat.config import CerebrasCompatConfig -from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin - -from ..cerebras.models import MODEL_ENTRIES - - -class CerebrasCompatInferenceAdapter(LiteLLMOpenAIMixin): - _config: CerebrasCompatConfig - - def __init__(self, config: CerebrasCompatConfig): - LiteLLMOpenAIMixin.__init__( - self, - model_entries=MODEL_ENTRIES, - api_key_from_config=config.api_key, - provider_data_api_key_field="cerebras_api_key", - openai_compat_api_base=config.openai_compat_api_base, - ) - self.config = config - - async def initialize(self): - await super().initialize() - - async def shutdown(self): - await super().shutdown() diff --git a/llama_stack/providers/remote/inference/cerebras_openai_compat/config.py b/llama_stack/providers/remote/inference/cerebras_openai_compat/config.py deleted file mode 100644 index cb8daff6a..000000000 --- a/llama_stack/providers/remote/inference/cerebras_openai_compat/config.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. 
- -from typing import Any - -from pydantic import BaseModel, Field - -from llama_stack.schema_utils import json_schema_type - - -class CerebrasProviderDataValidator(BaseModel): - cerebras_api_key: str | None = Field( - default=None, - description="API key for Cerebras models", - ) - - -@json_schema_type -class CerebrasCompatConfig(BaseModel): - api_key: str | None = Field( - default=None, - description="The Cerebras API key", - ) - - openai_compat_api_base: str = Field( - default="https://api.cerebras.ai/v1", - description="The URL for the Cerebras API server", - ) - - @classmethod - def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> dict[str, Any]: - return { - "openai_compat_api_base": "https://api.cerebras.ai/v1", - "api_key": api_key, - } diff --git a/llama_stack/providers/remote/inference/fireworks_openai_compat/__init__.py b/llama_stack/providers/remote/inference/fireworks_openai_compat/__init__.py deleted file mode 100644 index 15a666cb6..000000000 --- a/llama_stack/providers/remote/inference/fireworks_openai_compat/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.inference import InferenceProvider - -from .config import FireworksCompatConfig - - -async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> InferenceProvider: - # import dynamically so the import is used only when it is needed - from .fireworks import FireworksCompatInferenceAdapter - - adapter = FireworksCompatInferenceAdapter(config) - return adapter diff --git a/llama_stack/providers/remote/inference/fireworks_openai_compat/config.py b/llama_stack/providers/remote/inference/fireworks_openai_compat/config.py deleted file mode 100644 index bf38cdd2b..000000000 --- a/llama_stack/providers/remote/inference/fireworks_openai_compat/config.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any - -from pydantic import BaseModel, Field - -from llama_stack.schema_utils import json_schema_type - - -class FireworksProviderDataValidator(BaseModel): - fireworks_api_key: str | None = Field( - default=None, - description="API key for Fireworks models", - ) - - -@json_schema_type -class FireworksCompatConfig(BaseModel): - api_key: str | None = Field( - default=None, - description="The Fireworks API key", - ) - - openai_compat_api_base: str = Field( - default="https://api.fireworks.ai/inference/v1", - description="The URL for the Fireworks API server", - ) - - @classmethod - def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> dict[str, Any]: - return { - "openai_compat_api_base": "https://api.fireworks.ai/inference/v1", - "api_key": api_key, - } diff --git a/llama_stack/providers/remote/inference/fireworks_openai_compat/fireworks.py b/llama_stack/providers/remote/inference/fireworks_openai_compat/fireworks.py deleted file mode 100644 index f6045e0eb..000000000 --- a/llama_stack/providers/remote/inference/fireworks_openai_compat/fireworks.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.providers.remote.inference.fireworks_openai_compat.config import FireworksCompatConfig -from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin - -from ..fireworks.models import MODEL_ENTRIES - - -class FireworksCompatInferenceAdapter(LiteLLMOpenAIMixin): - _config: FireworksCompatConfig - - def __init__(self, config: FireworksCompatConfig): - LiteLLMOpenAIMixin.__init__( - self, - model_entries=MODEL_ENTRIES, - api_key_from_config=config.api_key, - provider_data_api_key_field="fireworks_api_key", - openai_compat_api_base=config.openai_compat_api_base, - ) - self.config = config - - async def initialize(self): - await super().initialize() - - async def shutdown(self): - await super().shutdown() diff --git a/llama_stack/providers/remote/inference/groq_openai_compat/__init__.py b/llama_stack/providers/remote/inference/groq_openai_compat/__init__.py deleted file mode 100644 index 794cdebd7..000000000 --- a/llama_stack/providers/remote/inference/groq_openai_compat/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.inference import InferenceProvider - -from .config import GroqCompatConfig - - -async def get_adapter_impl(config: GroqCompatConfig, _deps) -> InferenceProvider: - # import dynamically so the import is used only when it is needed - from .groq import GroqCompatInferenceAdapter - - adapter = GroqCompatInferenceAdapter(config) - return adapter diff --git a/llama_stack/providers/remote/inference/groq_openai_compat/config.py b/llama_stack/providers/remote/inference/groq_openai_compat/config.py deleted file mode 100644 index 481f740f9..000000000 --- a/llama_stack/providers/remote/inference/groq_openai_compat/config.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any - -from pydantic import BaseModel, Field - -from llama_stack.schema_utils import json_schema_type - - -class GroqProviderDataValidator(BaseModel): - groq_api_key: str | None = Field( - default=None, - description="API key for Groq models", - ) - - -@json_schema_type -class GroqCompatConfig(BaseModel): - api_key: str | None = Field( - default=None, - description="The Groq API key", - ) - - openai_compat_api_base: str = Field( - default="https://api.groq.com/openai/v1", - description="The URL for the Groq API server", - ) - - @classmethod - def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> dict[str, Any]: - return { - "openai_compat_api_base": "https://api.groq.com/openai/v1", - "api_key": api_key, - } diff --git a/llama_stack/providers/remote/inference/groq_openai_compat/groq.py b/llama_stack/providers/remote/inference/groq_openai_compat/groq.py deleted file mode 100644 index 30e18cd06..000000000 --- a/llama_stack/providers/remote/inference/groq_openai_compat/groq.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.providers.remote.inference.groq_openai_compat.config import GroqCompatConfig -from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin - -from ..groq.models import MODEL_ENTRIES - - -class GroqCompatInferenceAdapter(LiteLLMOpenAIMixin): - _config: GroqCompatConfig - - def __init__(self, config: GroqCompatConfig): - LiteLLMOpenAIMixin.__init__( - self, - model_entries=MODEL_ENTRIES, - api_key_from_config=config.api_key, - provider_data_api_key_field="groq_api_key", - openai_compat_api_base=config.openai_compat_api_base, - ) - self.config = config - - async def initialize(self): - await super().initialize() - - async def shutdown(self): - await super().shutdown() diff --git a/llama_stack/providers/remote/inference/sambanova_openai_compat/__init__.py b/llama_stack/providers/remote/inference/sambanova_openai_compat/__init__.py deleted file mode 100644 index 60afe91ca..000000000 --- a/llama_stack/providers/remote/inference/sambanova_openai_compat/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.inference import InferenceProvider - -from .config import SambaNovaCompatConfig - - -async def get_adapter_impl(config: SambaNovaCompatConfig, _deps) -> InferenceProvider: - # import dynamically so the import is used only when it is needed - from .sambanova import SambaNovaCompatInferenceAdapter - - adapter = SambaNovaCompatInferenceAdapter(config) - return adapter diff --git a/llama_stack/providers/remote/inference/sambanova_openai_compat/config.py b/llama_stack/providers/remote/inference/sambanova_openai_compat/config.py deleted file mode 100644 index 072fa85d1..000000000 --- a/llama_stack/providers/remote/inference/sambanova_openai_compat/config.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any - -from pydantic import BaseModel, Field - -from llama_stack.schema_utils import json_schema_type - - -class SambaNovaProviderDataValidator(BaseModel): - sambanova_api_key: str | None = Field( - default=None, - description="API key for SambaNova models", - ) - - -@json_schema_type -class SambaNovaCompatConfig(BaseModel): - api_key: str | None = Field( - default=None, - description="The SambaNova API key", - ) - - openai_compat_api_base: str = Field( - default="https://api.sambanova.ai/v1", - description="The URL for the SambaNova API server", - ) - - @classmethod - def sample_run_config(cls, api_key: str = "${env.SAMBANOVA_API_KEY}", **kwargs) -> dict[str, Any]: - return { - "openai_compat_api_base": "https://api.sambanova.ai/v1", - "api_key": api_key, - } diff --git a/llama_stack/providers/remote/inference/sambanova_openai_compat/sambanova.py b/llama_stack/providers/remote/inference/sambanova_openai_compat/sambanova.py deleted file mode 100644 index aa59028b6..000000000 --- a/llama_stack/providers/remote/inference/sambanova_openai_compat/sambanova.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.providers.remote.inference.sambanova_openai_compat.config import SambaNovaCompatConfig -from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin - -from ..sambanova.models import MODEL_ENTRIES - - -class SambaNovaCompatInferenceAdapter(LiteLLMOpenAIMixin): - _config: SambaNovaCompatConfig - - def __init__(self, config: SambaNovaCompatConfig): - LiteLLMOpenAIMixin.__init__( - self, - model_entries=MODEL_ENTRIES, - api_key_from_config=config.api_key, - provider_data_api_key_field="sambanova_api_key", - openai_compat_api_base=config.openai_compat_api_base, - ) - self.config = config - - async def initialize(self): - await super().initialize() - - async def shutdown(self): - await super().shutdown() diff --git a/llama_stack/providers/remote/inference/together_openai_compat/__init__.py b/llama_stack/providers/remote/inference/together_openai_compat/__init__.py deleted file mode 100644 index 8213fc5f4..000000000 --- a/llama_stack/providers/remote/inference/together_openai_compat/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.inference import InferenceProvider - -from .config import TogetherCompatConfig - - -async def get_adapter_impl(config: TogetherCompatConfig, _deps) -> InferenceProvider: - # import dynamically so the import is used only when it is needed - from .together import TogetherCompatInferenceAdapter - - adapter = TogetherCompatInferenceAdapter(config) - return adapter diff --git a/llama_stack/providers/remote/inference/together_openai_compat/config.py b/llama_stack/providers/remote/inference/together_openai_compat/config.py deleted file mode 100644 index 0c6d4f748..000000000 --- a/llama_stack/providers/remote/inference/together_openai_compat/config.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any - -from pydantic import BaseModel, Field - -from llama_stack.schema_utils import json_schema_type - - -class TogetherProviderDataValidator(BaseModel): - together_api_key: str | None = Field( - default=None, - description="API key for Together models", - ) - - -@json_schema_type -class TogetherCompatConfig(BaseModel): - api_key: str | None = Field( - default=None, - description="The Together API key", - ) - - openai_compat_api_base: str = Field( - default="https://api.together.xyz/v1", - description="The URL for the Together API server", - ) - - @classmethod - def sample_run_config(cls, api_key: str = "${env.TOGETHER_API_KEY}", **kwargs) -> dict[str, Any]: - return { - "openai_compat_api_base": "https://api.together.xyz/v1", - "api_key": api_key, - } diff --git a/llama_stack/providers/remote/inference/together_openai_compat/together.py b/llama_stack/providers/remote/inference/together_openai_compat/together.py deleted file mode 100644 index b463f5c35..000000000 --- a/llama_stack/providers/remote/inference/together_openai_compat/together.py +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. 
-# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.providers.remote.inference.together_openai_compat.config import TogetherCompatConfig -from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin - -from ..together.models import MODEL_ENTRIES - - -class TogetherCompatInferenceAdapter(LiteLLMOpenAIMixin): - _config: TogetherCompatConfig - - def __init__(self, config: TogetherCompatConfig): - LiteLLMOpenAIMixin.__init__( - self, - model_entries=MODEL_ENTRIES, - api_key_from_config=config.api_key, - provider_data_api_key_field="together_api_key", - openai_compat_api_base=config.openai_compat_api_base, - ) - self.config = config - - async def initialize(self): - await super().initialize() - - async def shutdown(self): - await super().shutdown() diff --git a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py index e5328bc59..c11de396b 100644 --- a/llama_stack/providers/utils/memory/openai_vector_store_mixin.py +++ b/llama_stack/providers/utils/memory/openai_vector_store_mixin.py @@ -161,7 +161,7 @@ class OpenAIVectorStoreMixin(ABC): async def openai_create_vector_store( self, - name: str, + name: str | None = None, file_ids: list[str] | None = None, expires_after: dict[str, Any] | None = None, chunking_strategy: dict[str, Any] | None = None, diff --git a/llama_stack/providers/utils/telemetry/sqlite_trace_store.py b/llama_stack/providers/utils/telemetry/sqlite_trace_store.py index af1145fe7..8dd6061a6 100644 --- a/llama_stack/providers/utils/telemetry/sqlite_trace_store.py +++ b/llama_stack/providers/utils/telemetry/sqlite_trace_store.py @@ -83,6 +83,7 @@ class SQLiteTraceStore(TraceStore): ) SELECT DISTINCT trace_id, root_span_id, start_time, end_time FROM filtered_traces + WHERE root_span_id IS NOT NULL LIMIT {limit} OFFSET {offset} """ @@ -166,7 +167,11 @@ class SQLiteTraceStore(TraceStore): return spans_by_id async def get_trace(self, trace_id: str) -> Trace: - query = "SELECT * FROM traces WHERE trace_id = ?" + query = """ + SELECT * + FROM traces t + WHERE t.trace_id = ? 
+ """ async with aiosqlite.connect(self.conn_string) as conn: conn.row_factory = aiosqlite.Row async with conn.execute(query, (trace_id,)) as cursor: diff --git a/llama_stack/templates/ci-tests/build.yaml b/llama_stack/templates/ci-tests/build.yaml index 0aed1d185..625e36e4f 100644 --- a/llama_stack/templates/ci-tests/build.yaml +++ b/llama_stack/templates/ci-tests/build.yaml @@ -19,12 +19,7 @@ distribution_spec: - remote::anthropic - remote::gemini - remote::groq - - remote::fireworks-openai-compat - remote::llama-openai-compat - - remote::together-openai-compat - - remote::groq-openai-compat - - remote::sambanova-openai-compat - - remote::cerebras-openai-compat - remote::sambanova - remote::passthrough - inline::sentence-transformers diff --git a/llama_stack/templates/ci-tests/run.yaml b/llama_stack/templates/ci-tests/run.yaml index cc7378c97..3757c6e60 100644 --- a/llama_stack/templates/ci-tests/run.yaml +++ b/llama_stack/templates/ci-tests/run.yaml @@ -90,36 +90,11 @@ providers: config: url: https://api.groq.com api_key: ${env.GROQ_API_KEY} - - provider_id: ${env.ENABLE_FIREWORKS_OPENAI_COMPAT:=__disabled__} - provider_type: remote::fireworks-openai-compat - config: - openai_compat_api_base: https://api.fireworks.ai/inference/v1 - api_key: ${env.FIREWORKS_API_KEY} - provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__} provider_type: remote::llama-openai-compat config: openai_compat_api_base: https://api.llama.com/compat/v1/ api_key: ${env.LLAMA_API_KEY} - - provider_id: ${env.ENABLE_TOGETHER_OPENAI_COMPAT:=__disabled__} - provider_type: remote::together-openai-compat - config: - openai_compat_api_base: https://api.together.xyz/v1 - api_key: ${env.TOGETHER_API_KEY} - - provider_id: ${env.ENABLE_GROQ_OPENAI_COMPAT:=__disabled__} - provider_type: remote::groq-openai-compat - config: - openai_compat_api_base: https://api.groq.com/openai/v1 - api_key: ${env.GROQ_API_KEY} - - provider_id: ${env.ENABLE_SAMBANOVA_OPENAI_COMPAT:=__disabled__} - provider_type: remote::sambanova-openai-compat - config: - openai_compat_api_base: https://api.sambanova.ai/v1 - api_key: ${env.SAMBANOVA_API_KEY} - - provider_id: ${env.ENABLE_CEREBRAS_OPENAI_COMPAT:=__disabled__} - provider_type: remote::cerebras-openai-compat - config: - openai_compat_api_base: https://api.cerebras.ai/v1 - api_key: ${env.CEREBRAS_API_KEY} - provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} provider_type: remote::sambanova config: diff --git a/llama_stack/templates/starter/build.yaml b/llama_stack/templates/starter/build.yaml index dc7565d46..8180124f6 100644 --- a/llama_stack/templates/starter/build.yaml +++ b/llama_stack/templates/starter/build.yaml @@ -19,12 +19,7 @@ distribution_spec: - remote::anthropic - remote::gemini - remote::groq - - remote::fireworks-openai-compat - remote::llama-openai-compat - - remote::together-openai-compat - - remote::groq-openai-compat - - remote::sambanova-openai-compat - - remote::cerebras-openai-compat - remote::sambanova - remote::passthrough - inline::sentence-transformers diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml index 46573848c..62e96d3b5 100644 --- a/llama_stack/templates/starter/run.yaml +++ b/llama_stack/templates/starter/run.yaml @@ -90,36 +90,11 @@ providers: config: url: https://api.groq.com api_key: ${env.GROQ_API_KEY} - - provider_id: ${env.ENABLE_FIREWORKS_OPENAI_COMPAT:=__disabled__} - provider_type: remote::fireworks-openai-compat - config: - openai_compat_api_base: https://api.fireworks.ai/inference/v1 - api_key: 
${env.FIREWORKS_API_KEY} - provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__} provider_type: remote::llama-openai-compat config: openai_compat_api_base: https://api.llama.com/compat/v1/ api_key: ${env.LLAMA_API_KEY} - - provider_id: ${env.ENABLE_TOGETHER_OPENAI_COMPAT:=__disabled__} - provider_type: remote::together-openai-compat - config: - openai_compat_api_base: https://api.together.xyz/v1 - api_key: ${env.TOGETHER_API_KEY} - - provider_id: ${env.ENABLE_GROQ_OPENAI_COMPAT:=__disabled__} - provider_type: remote::groq-openai-compat - config: - openai_compat_api_base: https://api.groq.com/openai/v1 - api_key: ${env.GROQ_API_KEY} - - provider_id: ${env.ENABLE_SAMBANOVA_OPENAI_COMPAT:=__disabled__} - provider_type: remote::sambanova-openai-compat - config: - openai_compat_api_base: https://api.sambanova.ai/v1 - api_key: ${env.SAMBANOVA_API_KEY} - - provider_id: ${env.ENABLE_CEREBRAS_OPENAI_COMPAT:=__disabled__} - provider_type: remote::cerebras-openai-compat - config: - openai_compat_api_base: https://api.cerebras.ai/v1 - api_key: ${env.CEREBRAS_API_KEY} - provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} provider_type: remote::sambanova config: diff --git a/scripts/install.sh b/scripts/install.sh index dae43df38..b5afe43b8 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -15,7 +15,7 @@ set -Eeuo pipefail PORT=8321 OLLAMA_PORT=11434 MODEL_ALIAS="llama3.2:3b" -SERVER_IMAGE="llamastack/distribution-ollama:0.2.2" +SERVER_IMAGE="docker.io/llamastack/distribution-ollama:0.2.2" WAIT_TIMEOUT=300 log(){ printf "\e[1;32m%s\e[0m\n" "$*"; } @@ -165,7 +165,7 @@ log "🦙 Starting Ollama…" $ENGINE run -d "${PLATFORM_OPTS[@]}" --name ollama-server \ --network llama-net \ -p "${OLLAMA_PORT}:${OLLAMA_PORT}" \ - ollama/ollama > /dev/null 2>&1 + docker.io/ollama/ollama > /dev/null 2>&1 if ! wait_for_service "http://localhost:${OLLAMA_PORT}/" "Ollama" "$WAIT_TIMEOUT" "Ollama daemon"; then log "❌ Ollama daemon did not become ready in ${WAIT_TIMEOUT}s; dumping container logs:" diff --git a/tests/integration/telemetry/test_telemetry.py b/tests/integration/telemetry/test_telemetry.py index 9df03da70..d363edbc0 100644 --- a/tests/integration/telemetry/test_telemetry.py +++ b/tests/integration/telemetry/test_telemetry.py @@ -47,6 +47,9 @@ def setup_telemetry_data(llama_stack_client, text_model_id): if len(traces) < 4: pytest.fail(f"Failed to create sufficient telemetry data after 30s. 
Got {len(traces)} traces.") + # Wait for 5 seconds to ensure traces have completed logging + time.sleep(5) + yield diff --git a/tests/unit/distribution/routers/test_routing_tables.py b/tests/unit/distribution/routers/test_routing_tables.py index 3ba042bd9..30f795d33 100644 --- a/tests/unit/distribution/routers/test_routing_tables.py +++ b/tests/unit/distribution/routers/test_routing_tables.py @@ -11,17 +11,15 @@ from unittest.mock import AsyncMock from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.datasets.datasets import Dataset, DatasetPurpose, URIDataSource from llama_stack.apis.datatypes import Api -from llama_stack.apis.models import Model, ModelType +from llama_stack.apis.models import Model from llama_stack.apis.shields.shields import Shield from llama_stack.apis.tools import ListToolDefsResponse, ToolDef, ToolGroup, ToolParameter -from llama_stack.apis.vector_dbs.vector_dbs import VectorDB from llama_stack.distribution.routing_tables.benchmarks import BenchmarksRoutingTable from llama_stack.distribution.routing_tables.datasets import DatasetsRoutingTable from llama_stack.distribution.routing_tables.models import ModelsRoutingTable from llama_stack.distribution.routing_tables.scoring_functions import ScoringFunctionsRoutingTable from llama_stack.distribution.routing_tables.shields import ShieldsRoutingTable from llama_stack.distribution.routing_tables.toolgroups import ToolGroupsRoutingTable -from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable class Impl: @@ -54,17 +52,6 @@ class SafetyImpl(Impl): return shield -class VectorDBImpl(Impl): - def __init__(self): - super().__init__(Api.vector_io) - - async def register_vector_db(self, vector_db: VectorDB): - return vector_db - - async def unregister_vector_db(self, vector_db_id: str): - return vector_db_id - - class DatasetsImpl(Impl): def __init__(self): super().__init__(Api.datasetio) @@ -173,36 +160,6 @@ async def test_shields_routing_table(cached_disk_dist_registry): assert "test-shield-2" in shield_ids -async def test_vectordbs_routing_table(cached_disk_dist_registry): - table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) - await table.initialize() - - m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) - await m_table.initialize() - await m_table.register_model( - model_id="test-model", - provider_id="test_provider", - metadata={"embedding_dimension": 128}, - model_type=ModelType.embedding, - ) - - # Register multiple vector databases and verify listing - await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model") - await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model") - vector_dbs = await table.list_vector_dbs() - - assert len(vector_dbs.data) == 2 - vector_db_ids = {v.identifier for v in vector_dbs.data} - assert "test-vectordb" in vector_db_ids - assert "test-vectordb-2" in vector_db_ids - - await table.unregister_vector_db(vector_db_id="test-vectordb") - await table.unregister_vector_db(vector_db_id="test-vectordb-2") - - vector_dbs = await table.list_vector_dbs() - assert len(vector_dbs.data) == 0 - - async def test_datasets_routing_table(cached_disk_dist_registry): table = DatasetsRoutingTable({"localfs": DatasetsImpl()}, cached_disk_dist_registry, {}) await table.initialize() diff --git a/tests/unit/distribution/routing_tables/__init__.py b/tests/unit/distribution/routing_tables/__init__.py new file mode
100644 index 000000000..756f351d8 --- /dev/null +++ b/tests/unit/distribution/routing_tables/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/tests/unit/distribution/routing_tables/test_vector_dbs.py b/tests/unit/distribution/routing_tables/test_vector_dbs.py new file mode 100644 index 000000000..28887e1cf --- /dev/null +++ b/tests/unit/distribution/routing_tables/test_vector_dbs.py @@ -0,0 +1,274 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Unit tests for the routing tables vector_dbs + +import time +from unittest.mock import AsyncMock + +import pytest + +from llama_stack.apis.datatypes import Api +from llama_stack.apis.models import ModelType +from llama_stack.apis.vector_dbs.vector_dbs import VectorDB +from llama_stack.apis.vector_io.vector_io import ( + VectorStoreContent, + VectorStoreDeleteResponse, + VectorStoreFileContentsResponse, + VectorStoreFileCounts, + VectorStoreFileDeleteResponse, + VectorStoreFileObject, + VectorStoreObject, + VectorStoreSearchResponsePage, +) +from llama_stack.distribution.access_control.datatypes import AccessRule, Scope +from llama_stack.distribution.datatypes import User +from llama_stack.distribution.request_headers import request_provider_data_context +from llama_stack.distribution.routing_tables.vector_dbs import VectorDBsRoutingTable +from tests.unit.distribution.routers.test_routing_tables import Impl, InferenceImpl, ModelsRoutingTable + + +class VectorDBImpl(Impl): + def __init__(self): + super().__init__(Api.vector_io) + + async def register_vector_db(self, vector_db: VectorDB): + return vector_db + + async def unregister_vector_db(self, vector_db_id: str): + return vector_db_id + + async def openai_retrieve_vector_store(self, vector_store_id): + return VectorStoreObject( + id=vector_store_id, + name="Test Store", + created_at=int(time.time()), + file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0), + ) + + async def openai_update_vector_store(self, vector_store_id, **kwargs): + return VectorStoreObject( + id=vector_store_id, + name="Updated Store", + created_at=int(time.time()), + file_counts=VectorStoreFileCounts(completed=0, cancelled=0, failed=0, in_progress=0, total=0), + ) + + async def openai_delete_vector_store(self, vector_store_id): + return VectorStoreDeleteResponse(id=vector_store_id, object="vector_store.deleted", deleted=True) + + async def openai_search_vector_store(self, vector_store_id, query, **kwargs): + return VectorStoreSearchResponsePage( + object="vector_store.search_results.page", search_query="query", data=[], has_more=False, next_page=None + ) + + async def openai_attach_file_to_vector_store(self, vector_store_id, file_id, **kwargs): + return VectorStoreFileObject( + id=file_id, + status="completed", + chunking_strategy={"type": "auto"}, + created_at=int(time.time()), + vector_store_id=vector_store_id, + ) + + async def openai_list_files_in_vector_store(self, vector_store_id, **kwargs): + return [ + VectorStoreFileObject( + id="1", + status="completed", + chunking_strategy={"type": "auto"}, + created_at=int(time.time()), + vector_store_id=vector_store_id, + ) + ] + + async def openai_retrieve_vector_store_file(self, 
vector_store_id, file_id): + return VectorStoreFileObject( + id=file_id, + status="completed", + chunking_strategy={"type": "auto"}, + created_at=int(time.time()), + vector_store_id=vector_store_id, + ) + + async def openai_retrieve_vector_store_file_contents(self, vector_store_id, file_id): + return VectorStoreFileContentsResponse( + file_id=file_id, + filename="Sample File name", + attributes={"key": "value"}, + content=[VectorStoreContent(type="text", text="Sample content")], + ) + + async def openai_update_vector_store_file(self, vector_store_id, file_id, **kwargs): + return VectorStoreFileObject( + id=file_id, + status="completed", + chunking_strategy={"type": "auto"}, + created_at=int(time.time()), + vector_store_id=vector_store_id, + ) + + async def openai_delete_vector_store_file(self, vector_store_id, file_id): + return VectorStoreFileDeleteResponse(id=file_id, deleted=True) + + +async def test_vectordbs_routing_table(cached_disk_dist_registry): + table = VectorDBsRoutingTable({"test_provider": VectorDBImpl()}, cached_disk_dist_registry, {}) + await table.initialize() + + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, {}) + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + # Register multiple vector databases and verify listing + await table.register_vector_db(vector_db_id="test-vectordb", embedding_model="test-model") + await table.register_vector_db(vector_db_id="test-vectordb-2", embedding_model="test-model") + vector_dbs = await table.list_vector_dbs() + + assert len(vector_dbs.data) == 2 + vector_db_ids = {v.identifier for v in vector_dbs.data} + assert "test-vectordb" in vector_db_ids + assert "test-vectordb-2" in vector_db_ids + + await table.unregister_vector_db(vector_db_id="test-vectordb") + await table.unregister_vector_db(vector_db_id="test-vectordb-2") + + vector_dbs = await table.list_vector_dbs() + assert len(vector_dbs.data) == 0 + + +async def test_openai_vector_stores_routing_table_roles(cached_disk_dist_registry): + impl = VectorDBImpl() + impl.openai_retrieve_vector_store = AsyncMock(return_value="OK") + table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=[]) + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[]) + authorized_table = "vs1" + authorized_team = "team1" + unauthorized_team = "team2" + + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + authorized_user = User(principal="alice", attributes={"roles": [authorized_team]}) + with request_provider_data_context({}, authorized_user): + _ = await table.register_vector_db(vector_db_id="vs1", embedding_model="test-model") + + # Authorized reader + with request_provider_data_context({}, authorized_user): + res = await table.openai_retrieve_vector_store(authorized_table) + assert res == "OK" + + # Authorized updater + impl.openai_update_vector_store_file = AsyncMock(return_value="UPDATED") + with request_provider_data_context({}, authorized_user): + res = await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"}) + assert res == "UPDATED" + + # Unauthorized reader + unauthorized_user = User(principal="eve", attributes={"roles": 
[unauthorized_team]}) + with request_provider_data_context({}, unauthorized_user): + with pytest.raises(ValueError): + await table.openai_retrieve_vector_store(authorized_table) + + # Unauthorized updater + with request_provider_data_context({}, unauthorized_user): + with pytest.raises(ValueError): + await table.openai_update_vector_store_file(authorized_table, file_id="file1", attributes={"foo": "bar"}) + + # Authorized deleter + impl.openai_delete_vector_store_file = AsyncMock(return_value="DELETED") + with request_provider_data_context({}, authorized_user): + res = await table.openai_delete_vector_store_file(authorized_table, file_id="file1") + assert res == "DELETED" + + # Unauthorized deleter + with request_provider_data_context({}, unauthorized_user): + with pytest.raises(ValueError): + await table.openai_delete_vector_store_file(authorized_table, file_id="file1") + + +async def test_openai_vector_stores_routing_table_actions(cached_disk_dist_registry): + impl = VectorDBImpl() + + policy = [ + AccessRule(permit=Scope(actions=["create", "read", "update", "delete"]), when="user with admin in roles"), + AccessRule(permit=Scope(actions=["read"]), when="user with reader in roles"), + ] + + table = VectorDBsRoutingTable({"test_provider": impl}, cached_disk_dist_registry, policy=policy) + m_table = ModelsRoutingTable({"test_provider": InferenceImpl()}, cached_disk_dist_registry, policy=[]) + + vector_db_id = "vs1" + file_id = "file-1" + + admin_user = User(principal="admin", attributes={"roles": ["admin"]}) + read_only_user = User(principal="reader", attributes={"roles": ["reader"]}) + no_access_user = User(principal="outsider", attributes={"roles": ["no_access"]}) + + await m_table.initialize() + await m_table.register_model( + model_id="test-model", + provider_id="test_provider", + metadata={"embedding_dimension": 128}, + model_type=ModelType.embedding, + ) + + with request_provider_data_context({}, admin_user): + await table.register_vector_db(vector_db_id=vector_db_id, embedding_model="test-model") + + read_methods = [ + (table.openai_retrieve_vector_store, (vector_db_id,), {}), + (table.openai_search_vector_store, (vector_db_id, "query"), {}), + (table.openai_list_files_in_vector_store, (vector_db_id,), {}), + (table.openai_retrieve_vector_store_file, (vector_db_id, file_id), {}), + (table.openai_retrieve_vector_store_file_contents, (vector_db_id, file_id), {}), + ] + update_methods = [ + (table.openai_update_vector_store, (vector_db_id,), {"name": "Updated DB"}), + (table.openai_attach_file_to_vector_store, (vector_db_id, file_id), {}), + (table.openai_update_vector_store_file, (vector_db_id, file_id), {"attributes": {"key": "value"}}), + ] + delete_methods = [ + (table.openai_delete_vector_store_file, (vector_db_id, file_id), {}), + (table.openai_delete_vector_store, (vector_db_id,), {}), + ] + + for user in [admin_user, read_only_user]: + with request_provider_data_context({}, user): + for method, args, kwargs in read_methods: + result = await method(*args, **kwargs) + assert result is not None, f"Read operation failed with user {user.principal}" + + with request_provider_data_context({}, no_access_user): + for method, args, kwargs in read_methods: + with pytest.raises(ValueError): + await method(*args, **kwargs) + + with request_provider_data_context({}, admin_user): + for method, args, kwargs in update_methods: + result = await method(*args, **kwargs) + assert result is not None, "Update operation failed with admin user" + + with request_provider_data_context({}, admin_user): + 
for method, args, kwargs in delete_methods: + result = await method(*args, **kwargs) + assert result is not None, "Delete operation failed with admin user" + + for user in [read_only_user, no_access_user]: + with request_provider_data_context({}, user): + for method, args, kwargs in delete_methods: + with pytest.raises(ValueError): + await method(*args, **kwargs) diff --git a/tests/unit/providers/agent/test_get_raw_document_text.py b/tests/unit/providers/agent/test_get_raw_document_text.py new file mode 100644 index 000000000..eb481c0d8 --- /dev/null +++ b/tests/unit/providers/agent/test_get_raw_document_text.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import warnings +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from llama_stack.apis.agents import Document +from llama_stack.apis.common.content_types import URL, TextContentItem +from llama_stack.providers.inline.agents.meta_reference.agent_instance import get_raw_document_text + + +async def test_get_raw_document_text_supports_text_mime_types(): + """Test that the function accepts text/* mime types.""" + document = Document(content="Sample text content", mime_type="text/plain") + + result = await get_raw_document_text(document) + assert result == "Sample text content" + + +async def test_get_raw_document_text_supports_yaml_mime_type(): + """Test that the function accepts application/yaml mime type.""" + yaml_content = """ + name: test + version: 1.0 + items: + - item1 + - item2 + """ + + document = Document(content=yaml_content, mime_type="application/yaml") + + result = await get_raw_document_text(document) + assert result == yaml_content + + +async def test_get_raw_document_text_supports_deprecated_text_yaml_with_warning(): + """Test that the function accepts text/yaml but emits a deprecation warning.""" + yaml_content = """ + name: test + version: 1.0 + items: + - item1 + - item2 + """ + + document = Document(content=yaml_content, mime_type="text/yaml") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = await get_raw_document_text(document) + + # Check that result is correct + assert result == yaml_content + + # Check that exactly one warning was issued + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "text/yaml" in str(w[0].message) + assert "application/yaml" in str(w[0].message) + assert "deprecated" in str(w[0].message).lower() + + +async def test_get_raw_document_text_deprecated_text_yaml_with_url(): + """Test that text/yaml works with URL content and emits warning.""" + yaml_content = "name: test\nversion: 1.0" + + with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load: + mock_load.return_value = yaml_content + + document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="text/yaml") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = await get_raw_document_text(document) + + # Check that result is correct + assert result == yaml_content + mock_load.assert_called_once_with("https://example.com/config.yaml") + + # Check that deprecation warning was issued + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "text/yaml" in str(w[0].message) + + +async def 
test_get_raw_document_text_deprecated_text_yaml_with_text_content_item(): + """Test that text/yaml works with TextContentItem and emits warning.""" + yaml_content = "key: value\nlist:\n - item1\n - item2" + + document = Document(content=TextContentItem(text=yaml_content), mime_type="text/yaml") + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + result = await get_raw_document_text(document) + + # Check that result is correct + assert result == yaml_content + + # Check that deprecation warning was issued + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert "text/yaml" in str(w[0].message) + + +async def test_get_raw_document_text_rejects_unsupported_mime_types(): + """Test that the function rejects unsupported mime types.""" + document = Document( + content="Some content", + mime_type="application/json", # Not supported + ) + + with pytest.raises(ValueError, match="Unexpected document mime type: application/json"): + await get_raw_document_text(document) + + +async def test_get_raw_document_text_with_url_content(): + """Test that the function handles URL content correctly.""" + mock_response = AsyncMock() + mock_response.text = "Content from URL" + + with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load: + mock_load.return_value = "Content from URL" + + document = Document(content=URL(uri="https://example.com/test.txt"), mime_type="text/plain") + + result = await get_raw_document_text(document) + assert result == "Content from URL" + mock_load.assert_called_once_with("https://example.com/test.txt") + + +async def test_get_raw_document_text_with_yaml_url(): + """Test that the function handles YAML URLs correctly.""" + yaml_content = "name: test\nversion: 1.0" + + with patch("llama_stack.providers.inline.agents.meta_reference.agent_instance.load_data_from_url") as mock_load: + mock_load.return_value = yaml_content + + document = Document(content=URL(uri="https://example.com/config.yaml"), mime_type="application/yaml") + + result = await get_raw_document_text(document) + assert result == yaml_content + mock_load.assert_called_once_with("https://example.com/config.yaml") + + +async def test_get_raw_document_text_with_text_content_item(): + """Test that the function handles TextContentItem correctly.""" + document = Document(content=TextContentItem(text="Text content item"), mime_type="text/plain") + + result = await get_raw_document_text(document) + assert result == "Text content item" + + +async def test_get_raw_document_text_with_yaml_text_content_item(): + """Test that the function handles YAML TextContentItem correctly.""" + yaml_content = "key: value\nlist:\n - item1\n - item2" + + document = Document(content=TextContentItem(text=yaml_content), mime_type="application/yaml") + + result = await get_raw_document_text(document) + assert result == yaml_content + + +async def test_get_raw_document_text_rejects_unexpected_content_type(): + """Test that the function rejects unexpected document content types.""" + # Create a mock document that bypasses Pydantic validation + mock_document = MagicMock(spec=Document) + mock_document.mime_type = "text/plain" + mock_document.content = 123 # Unexpected content type (not str, URL, or TextContentItem) + + with pytest.raises(ValueError, match="Unexpected document content type: "): + await get_raw_document_text(mock_document)