From 706b4ca651928941ec778bccbf1b9f5751025b79 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Tue, 18 Mar 2025 13:54:10 -0500 Subject: [PATCH 01/19] feat: support nvidia hosted vision models (llama 3.2 11b/90b) (#1278) # What does this PR do? support nvidia hosted 3.2 11b/90b vision models. they are not hosted on the common https://integrate.api.nvidia.com/v1. they are hosted on their own individual urls. ## Test Plan `LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -s -v tests/client-sdk/inference/test_vision_inference.py --inference-model=meta/llama-3.2-11b-vision-instruct -k image` --- .../remote/inference/nvidia/nvidia.py | 51 +++++++++++++++---- 1 file changed, 42 insertions(+), 9 deletions(-) diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index b59da79eb..69e6335c6 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -6,6 +6,7 @@ import logging import warnings +from functools import lru_cache from typing import AsyncIterator, List, Optional, Union from openai import APIConnectionError, AsyncOpenAI, BadRequestError @@ -82,12 +83,42 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): # ) self._config = config - # make sure the client lives longer than any async calls - self._client = AsyncOpenAI( - base_url=f"{self._config.url}/v1", - api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"), - timeout=self._config.timeout, - ) + + @lru_cache # noqa: B019 + def _get_client(self, provider_model_id: str) -> AsyncOpenAI: + """ + For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However, + some models are hosted on different URLs. This function returns the appropriate client + for the given provider_model_id. + + This relies on lru_cache and self._default_client to avoid creating a new client for each request + or for each model that is hosted on https://integrate.api.nvidia.com/v1. + + :param provider_model_id: The provider model ID + :return: An OpenAI client + """ + + @lru_cache # noqa: B019 + def _get_client_for_base_url(base_url: str) -> AsyncOpenAI: + """ + Maintain a single OpenAI client per base_url. + """ + return AsyncOpenAI( + base_url=base_url, + api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"), + timeout=self._config.timeout, + ) + + special_model_urls = { + "meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct", + "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct", + } + + base_url = f"{self._config.url}/v1" + if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls: + base_url = special_model_urls[provider_model_id] + + return _get_client_for_base_url(base_url) async def completion( self, @@ -105,9 +136,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): await check_health(self._config) # this raises errors + provider_model_id = self.get_provider_model_id(model_id) request = convert_completion_request( request=CompletionRequest( - model=self.get_provider_model_id(model_id), + model=provider_model_id, content=content, sampling_params=sampling_params, response_format=response_format, @@ -118,7 +150,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): ) try: - response = await self._client.completions.create(**request) + response = await self._get_client(provider_model_id).completions.create(**request) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e @@ -206,6 +238,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): await check_health(self._config) # this raises errors + provider_model_id = self.get_provider_model_id(model_id) request = await convert_chat_completion_request( request=ChatCompletionRequest( model=self.get_provider_model_id(model_id), @@ -221,7 +254,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): ) try: - response = await self._client.chat.completions.create(**request) + response = await self._get_client(provider_model_id).chat.completions.create(**request) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e From 814eb753216bf1e12c26a6e9267449e3fc167dbf Mon Sep 17 00:00:00 2001 From: Ihar Hrachyshka Date: Tue, 18 Mar 2025 15:17:21 -0400 Subject: [PATCH 02/19] chore: enable ruff for ./scripts too (#1643) # What does this PR do? Enable ruff for scripts. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation) Signed-off-by: Ihar Hrachyshka --- pyproject.toml | 1 - scripts/gen-changelog.py | 8 +++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a006d69f9..7972bf211 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -114,7 +114,6 @@ exclude = [ "./.git", "./docs/*", "./build", - "./scripts", "./venv", "*.pyi", ".pre-commit-config.yaml", diff --git a/scripts/gen-changelog.py b/scripts/gen-changelog.py index ac4053339..3df2af06b 100755 --- a/scripts/gen-changelog.py +++ b/scripts/gen-changelog.py @@ -11,7 +11,7 @@ import requests def get_all_releases(token): - url = f"https://api.github.com/repos/meta-llama/llama-stack/releases" + url = "https://api.github.com/repos/meta-llama/llama-stack/releases" headers = {"Accept": "application/vnd.github.v3+json"} if token: @@ -22,9 +22,7 @@ def get_all_releases(token): if response.status_code == 200: return response.json() else: - raise Exception( - f"Error fetching releases: {response.status_code}, {response.text}" - ) + raise Exception(f"Error fetching releases: {response.status_code}, {response.text}") def clean_release_body(body): @@ -55,7 +53,7 @@ def merge_release_notes(output_file, token=None): releases = get_all_releases(token) with open(output_file, "w", encoding="utf-8") as md_file: - md_file.write(f"# Changelog\n\n") + md_file.write("# Changelog\n\n") for release in releases: md_file.write(f"# {release['tag_name']}\n") From 141b3c14dd23822117e8136d6ab466ca09170e3a Mon Sep 17 00:00:00 2001 From: Nathan Weinberg <31703736+nathan-weinberg@users.noreply.github.com> Date: Tue, 18 Mar 2025 16:39:46 -0400 Subject: [PATCH 03/19] docs: fix broken test path in CONTRIBUTING.md (#1679) # What does this PR do? fix broken test path in CONTRIBUTING.md Signed-off-by: Nathan Weinberg --- CONTRIBUTING.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index e458fec0a..505d6b162 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -86,7 +86,7 @@ LLAMA_STACK_CONFIG= And then use this dotenv file when running client SDK tests via the following: ```bash -uv run --env-file .env -- pytest -v tests/api/inference/test_text_inference.py +uv run --env-file .env -- pytest -v tests/integration/inference/test_text_inference.py ``` ## Pre-commit Hooks From cca9bd6cc3d0fc3046b91bd6fecf9b4a85f466cf Mon Sep 17 00:00:00 2001 From: Daniele Martinoli <86618610+dmartinol@users.noreply.github.com> Date: Tue, 18 Mar 2025 22:04:21 +0100 Subject: [PATCH 04/19] feat: Qdrant inline provider (#1273) # What does this PR do? Removed local execution option from the remote Qdrant provider and introduced an explicit inline provider for the embedded execution. Updated the ollama template to include this option: this part can be reverted in case we don't want to have two default `vector_io` providers. (Closes #1082) ## Test Plan Build and run an ollama distro: ```bash llama stack build --template ollama --image-type conda llama stack run --image-type conda ollama ``` Run one of the sample ingestionapplicatinos like [rag_with_vector_db.py](https://github.com/meta-llama/llama-stack-apps/blob/main/examples/agents/rag_with_vector_db.py), but replace this line: ```py selected_vector_provider = vector_providers[0] ``` with the following, to use the `qdrant` provider: ```py selected_vector_provider = vector_providers[1] ``` After running the test code, verify the timestamp of the Qdrant store: ```bash % ls -ltr ~/.llama/distributions/ollama/qdrant.db/collection/test_vector_db_* total 784 -rw-r--r--@ 1 dmartino staff 401408 Feb 26 10:07 storage.sqlite ``` [//]: # (## Documentation) --------- Signed-off-by: Daniele Martinoli Co-authored-by: Francisco Arceo --- docs/source/providers/vector_io/qdrant.md | 21 +- .../inline/vector_io/qdrant/__init__.py | 19 ++ .../inline/vector_io/qdrant/config.py | 23 ++ llama_stack/providers/registry/vector_io.py | 8 + .../remote/vector_io/qdrant/config.py | 1 - .../remote/vector_io/qdrant/qdrant.py | 20 +- pyproject.toml | 3 +- tests/unit/providers/vector_io/conftest.py | 42 ++++ tests/unit/providers/vector_io/test_qdrant.py | 135 ++++++++++++ .../providers/vector_io/test_sqlite_vec.py | 32 +-- uv.lock | 198 +++++++++++++++++- 11 files changed, 454 insertions(+), 48 deletions(-) create mode 100644 llama_stack/providers/inline/vector_io/qdrant/__init__.py create mode 100644 llama_stack/providers/inline/vector_io/qdrant/config.py create mode 100644 tests/unit/providers/vector_io/conftest.py create mode 100644 tests/unit/providers/vector_io/test_qdrant.py diff --git a/docs/source/providers/vector_io/qdrant.md b/docs/source/providers/vector_io/qdrant.md index a0de0be98..8b0cbeef8 100644 --- a/docs/source/providers/vector_io/qdrant.md +++ b/docs/source/providers/vector_io/qdrant.md @@ -3,21 +3,36 @@ orphan: true --- # Qdrant -[Qdrant](https://qdrant.tech/documentation/) is a remote vector database provider for Llama Stack. It +[Qdrant](https://qdrant.tech/documentation/) is an inline and remote vector database provider for Llama Stack. It allows you to store and query vectors directly in memory. That means you'll get fast and efficient vector retrieval. +> By default, Qdrant stores vectors in RAM, delivering incredibly fast access for datasets that fit comfortably in +> memory. But when your dataset exceeds RAM capacity, Qdrant offers Memmap as an alternative. +> +> \[[An Introduction to Vector Databases](https://qdrant.tech/articles/what-is-a-vector-database/)\] + + + ## Features -- Easy to use +- Lightweight and easy to use - Fully integrated with Llama Stack +- Apache 2.0 license terms +- Store embeddings and their metadata +- Supports search by + [Keyword](https://qdrant.tech/articles/qdrant-introduces-full-text-filters-and-indexes/) + and [Hybrid](https://qdrant.tech/articles/hybrid-search/#building-a-hybrid-search-system-in-qdrant) search +- [Multilingual and Multimodal retrieval](https://qdrant.tech/documentation/multimodal-search/) +- [Medatata filtering](https://qdrant.tech/articles/vector-search-filtering/) +- [GPU support](https://qdrant.tech/documentation/guides/running-with-gpu/) ## Usage To use Qdrant in your Llama Stack project, follow these steps: 1. Install the necessary dependencies. -2. Configure your Llama Stack project to use Faiss. +2. Configure your Llama Stack project to use Qdrant. 3. Start storing and querying vectors. ## Installation diff --git a/llama_stack/providers/inline/vector_io/qdrant/__init__.py b/llama_stack/providers/inline/vector_io/qdrant/__init__.py new file mode 100644 index 000000000..8f0b91c61 --- /dev/null +++ b/llama_stack/providers/inline/vector_io/qdrant/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Dict + +from llama_stack.providers.datatypes import Api, ProviderSpec + +from .config import QdrantVectorIOConfig + + +async def get_adapter_impl(config: QdrantVectorIOConfig, deps: Dict[Api, ProviderSpec]): + from llama_stack.providers.remote.vector_io.qdrant.qdrant import QdrantVectorIOAdapter + + impl = QdrantVectorIOAdapter(config, deps[Api.inference]) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/vector_io/qdrant/config.py b/llama_stack/providers/inline/vector_io/qdrant/config.py new file mode 100644 index 000000000..282e951b0 --- /dev/null +++ b/llama_stack/providers/inline/vector_io/qdrant/config.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from typing import Any, Dict + +from pydantic import BaseModel + +from llama_stack.schema_utils import json_schema_type + + +@json_schema_type +class QdrantVectorIOConfig(BaseModel): + path: str + + @classmethod + def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]: + return { + "path": "${env.QDRANT_PATH:~/.llama/" + __distro_dir__ + "}/" + "qdrant.db", + } diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index fbc495d83..93031763d 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -92,6 +92,14 @@ def available_providers() -> List[ProviderSpec]: ), api_dependencies=[Api.inference], ), + InlineProviderSpec( + api=Api.vector_io, + provider_type="inline::qdrant", + pip_packages=["qdrant-client"], + module="llama_stack.providers.inline.vector_io.qdrant", + config_class="llama_stack.providers.inline.vector_io.qdrant.QdrantVectorIOConfig", + api_dependencies=[Api.inference], + ), remote_provider_spec( Api.vector_io, AdapterSpec( diff --git a/llama_stack/providers/remote/vector_io/qdrant/config.py b/llama_stack/providers/remote/vector_io/qdrant/config.py index ce68aa492..6d7eebe23 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/config.py +++ b/llama_stack/providers/remote/vector_io/qdrant/config.py @@ -23,7 +23,6 @@ class QdrantVectorIOConfig(BaseModel): prefix: Optional[str] = None timeout: Optional[int] = None host: Optional[str] = None - path: Optional[str] = None @classmethod def sample_run_config(cls, **kwargs: Any) -> Dict[str, Any]: diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index 586b8ca95..9e7788dc0 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -6,7 +6,7 @@ import logging import uuid -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from numpy.typing import NDArray from qdrant_client import AsyncQdrantClient, models @@ -16,12 +16,13 @@ from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate +from llama_stack.providers.inline.vector_io.qdrant import QdrantVectorIOConfig as InlineQdrantVectorIOConfig from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, ) -from .config import QdrantVectorIOConfig +from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig log = logging.getLogger(__name__) CHUNK_ID_KEY = "_chunk_id" @@ -99,17 +100,19 @@ class QdrantIndex(EmbeddingIndex): class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): - def __init__(self, config: QdrantVectorIOConfig, inference_api: Api.inference) -> None: + def __init__( + self, config: Union[RemoteQdrantVectorIOConfig, InlineQdrantVectorIOConfig], inference_api: Api.inference + ) -> None: self.config = config - self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True)) + self.client: AsyncQdrantClient = None self.cache = {} self.inference_api = inference_api async def initialize(self) -> None: - pass + self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True)) async def shutdown(self) -> None: - self.client.close() + await self.client.close() async def register_vector_db( self, @@ -123,6 +126,11 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): self.cache[vector_db.identifier] = index + async def unregister_vector_db(self, vector_db_id: str) -> None: + if vector_db_id in self.cache: + await self.cache[vector_db_id].index.delete() + del self.cache[vector_db_id] + async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> Optional[VectorDBWithIndex]: if vector_db_id in self.cache: return self.cache[vector_db_id] diff --git a/pyproject.toml b/pyproject.toml index 7972bf211..f57b91462 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ dev = [ "ruamel.yaml", # needed for openapi generator ] # These are the dependencies required for running unit tests. -unit = ["sqlite-vec", "openai", "aiosqlite", "pypdf", "chardet"] +unit = ["sqlite-vec", "openai", "aiosqlite", "pypdf", "chardet", "qdrant-client"] # These are the core dependencies required for running integration tests. They are shared across all # providers. If a provider requires additional dependencies, please add them to your environment # separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra @@ -247,6 +247,7 @@ exclude = [ "^llama_stack/providers/inline/vector_io/chroma/", "^llama_stack/providers/inline/vector_io/faiss/", "^llama_stack/providers/inline/vector_io/milvus/", + "^llama_stack/providers/inline/vector_io/qdrant/", "^llama_stack/providers/inline/vector_io/sqlite_vec/", "^llama_stack/providers/remote/agents/sample/", "^llama_stack/providers/remote/datasetio/huggingface/", diff --git a/tests/unit/providers/vector_io/conftest.py b/tests/unit/providers/vector_io/conftest.py new file mode 100644 index 000000000..3bcd0613f --- /dev/null +++ b/tests/unit/providers/vector_io/conftest.py @@ -0,0 +1,42 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import random + +import numpy as np +import pytest + +from llama_stack.apis.vector_io import Chunk + +EMBEDDING_DIMENSION = 384 + + +@pytest.fixture +def vector_db_id() -> str: + return f"test-vector-db-{random.randint(1, 100)}" + + +@pytest.fixture(scope="session") +def embedding_dimension() -> int: + return EMBEDDING_DIMENSION + + +@pytest.fixture(scope="session") +def sample_chunks(): + """Generates chunks that force multiple batches for a single document to expose ID conflicts.""" + n, k = 10, 3 + sample = [ + Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"}) + for j in range(k) + for i in range(n) + ] + return sample + + +@pytest.fixture(scope="session") +def sample_embeddings(sample_chunks): + np.random.seed(42) + return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks]) diff --git a/tests/unit/providers/vector_io/test_qdrant.py b/tests/unit/providers/vector_io/test_qdrant.py new file mode 100644 index 000000000..bc97719c0 --- /dev/null +++ b/tests/unit/providers/vector_io/test_qdrant.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import asyncio +import os +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio + +from llama_stack.apis.inference import EmbeddingsResponse, Inference +from llama_stack.apis.vector_io import ( + QueryChunksResponse, + VectorDB, + VectorDBStore, +) +from llama_stack.providers.inline.vector_io.qdrant.config import ( + QdrantVectorIOConfig as InlineQdrantVectorIOConfig, +) +from llama_stack.providers.remote.vector_io.qdrant.qdrant import ( + QdrantVectorIOAdapter, +) + +# This test is a unit test for the QdrantVectorIOAdapter class. This should only contain +# tests which are specific to this class. More general (API-level) tests should be placed in +# tests/integration/vector_io/ +# +# How to run this test: +# +# pytest tests/unit/providers/vector_io/test_qdrant.py \ +# -v -s --tb=short --disable-warnings --asyncio-mode=auto + + +@pytest.fixture +def qdrant_config(tmp_path) -> InlineQdrantVectorIOConfig: + return InlineQdrantVectorIOConfig(path=os.path.join(tmp_path, "qdrant.db")) + + +@pytest.fixture(scope="session") +def loop(): + return asyncio.new_event_loop() + + +@pytest.fixture +def mock_vector_db(vector_db_id) -> MagicMock: + mock_vector_db = MagicMock(spec=VectorDB) + mock_vector_db.embedding_model = "embedding_model" + mock_vector_db.identifier = vector_db_id + return mock_vector_db + + +@pytest.fixture +def mock_vector_db_store(mock_vector_db) -> MagicMock: + mock_store = MagicMock(spec=VectorDBStore) + mock_store.get_vector_db = AsyncMock(return_value=mock_vector_db) + return mock_store + + +@pytest.fixture +def mock_api_service(sample_embeddings): + mock_api_service = MagicMock(spec=Inference) + mock_api_service.embeddings = AsyncMock(return_value=EmbeddingsResponse(embeddings=sample_embeddings)) + return mock_api_service + + +@pytest_asyncio.fixture +async def qdrant_adapter(qdrant_config, mock_vector_db_store, mock_api_service, loop) -> QdrantVectorIOAdapter: + adapter = QdrantVectorIOAdapter(config=qdrant_config, inference_api=mock_api_service) + adapter.vector_db_store = mock_vector_db_store + await adapter.initialize() + yield adapter + await adapter.shutdown() + + +__QUERY = "Sample query" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("max_query_chunks, expected_chunks", [(2, 2), (100, 30)]) +async def test_qdrant_adapter_returns_expected_chunks( + qdrant_adapter: QdrantVectorIOAdapter, + vector_db_id, + sample_chunks, + sample_embeddings, + max_query_chunks, + expected_chunks, +) -> None: + assert qdrant_adapter is not None + await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks) + + index = await qdrant_adapter._get_and_cache_vector_db_index(vector_db_id=vector_db_id) + assert index is not None + + response = await qdrant_adapter.query_chunks( + query=__QUERY, + vector_db_id=vector_db_id, + params={"max_chunks": max_query_chunks}, + ) + assert isinstance(response, QueryChunksResponse) + assert len(response.chunks) == expected_chunks + + +# To by-pass attempt to convert a Mock to JSON +def _prepare_for_json(value: Any) -> str: + return str(value) + + +@patch("llama_stack.providers.utils.telemetry.trace_protocol._prepare_for_json", new=_prepare_for_json) +@pytest.mark.asyncio +async def test_qdrant_register_and_unregister_vector_db( + qdrant_adapter: QdrantVectorIOAdapter, + mock_vector_db, + sample_chunks, +) -> None: + # Initially, no collections + vector_db_id = mock_vector_db.identifier + assert len((await qdrant_adapter.client.get_collections()).collections) == 0 + + # Register does not create a collection + assert not (await qdrant_adapter.client.collection_exists(vector_db_id)) + await qdrant_adapter.register_vector_db(mock_vector_db) + assert not (await qdrant_adapter.client.collection_exists(vector_db_id)) + + # First insert creates the collection + await qdrant_adapter.insert_chunks(vector_db_id, sample_chunks) + assert await qdrant_adapter.client.collection_exists(vector_db_id) + + # Unregister deletes the collection + await qdrant_adapter.unregister_vector_db(vector_db_id) + assert not (await qdrant_adapter.client.collection_exists(vector_db_id)) + assert len((await qdrant_adapter.client.get_collections()).collections) == 0 diff --git a/tests/unit/providers/vector_io/test_sqlite_vec.py b/tests/unit/providers/vector_io/test_sqlite_vec.py index eb5660a85..cff988c53 100644 --- a/tests/unit/providers/vector_io/test_sqlite_vec.py +++ b/tests/unit/providers/vector_io/test_sqlite_vec.py @@ -29,8 +29,6 @@ from llama_stack.providers.inline.vector_io.sqlite_vec.sqlite_vec import ( # -v -s --tb=short --disable-warnings --asyncio-mode=auto SQLITE_VEC_PROVIDER = "sqlite_vec" -EMBEDDING_DIMENSION = 384 -EMBEDDING_MODEL = "all-MiniLM-L6-v2" @pytest.fixture(scope="session") @@ -50,26 +48,8 @@ def sqlite_connection(loop): @pytest_asyncio.fixture(scope="session", autouse=True) -async def sqlite_vec_index(sqlite_connection): - return await SQLiteVecIndex.create(dimension=EMBEDDING_DIMENSION, connection=sqlite_connection, bank_id="test_bank") - - -@pytest.fixture(scope="session") -def sample_chunks(): - """Generates chunks that force multiple batches for a single document to expose ID conflicts.""" - n, k = 10, 3 - sample = [ - Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"}) - for j in range(k) - for i in range(n) - ] - return sample - - -@pytest.fixture(scope="session") -def sample_embeddings(sample_chunks): - np.random.seed(42) - return np.array([np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) for _ in sample_chunks]) +async def sqlite_vec_index(sqlite_connection, embedding_dimension): + return await SQLiteVecIndex.create(dimension=embedding_dimension, connection=sqlite_connection, bank_id="test_bank") @pytest.mark.asyncio @@ -82,21 +62,21 @@ async def test_add_chunks(sqlite_vec_index, sample_chunks, sample_embeddings): @pytest.mark.asyncio -async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings): +async def test_query_chunks(sqlite_vec_index, sample_chunks, sample_embeddings, embedding_dimension): await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings) - query_embedding = np.random.rand(EMBEDDING_DIMENSION).astype(np.float32) + query_embedding = np.random.rand(embedding_dimension).astype(np.float32) response = await sqlite_vec_index.query(query_embedding, k=2, score_threshold=0.0) assert isinstance(response, QueryChunksResponse) assert len(response.chunks) == 2 @pytest.mark.asyncio -async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks): +async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dimension): """Test that chunk IDs do not conflict across batches when inserting chunks.""" # Reduce batch size to force multiple batches for same document # since there are 10 chunks per document and batch size is 2 batch_size = 2 - sample_embeddings = np.random.rand(len(sample_chunks), EMBEDDING_DIMENSION).astype(np.float32) + sample_embeddings = np.random.rand(len(sample_chunks), embedding_dimension).astype(np.float32) await sqlite_vec_index.add_chunks(sample_chunks, sample_embeddings, batch_size=batch_size) diff --git a/uv.lock b/uv.lock index 860b29241..b63d23b14 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.10" resolution-markers = [ "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')", @@ -8,9 +7,12 @@ resolution-markers = [ "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')", "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.11.*' and sys_platform == 'darwin'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux')", - "python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "python_full_version >= '3.12' and sys_platform == 'darwin'", + "(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')", + "(python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.13' and sys_platform == 'darwin'", + "python_full_version == '3.12.*' and sys_platform == 'darwin'", ] [[package]] @@ -793,6 +795,107 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/89/30/2bd0eb03a7dee7727cd2ec643d1e992979e62d5e7443507381cce0455132/googleapis_common_protos-1.67.0-py2.py3-none-any.whl", hash = "sha256:579de760800d13616f51cf8be00c876f00a9f146d3e6510e19d1f4111758b741", size = 164985 }, ] +[[package]] +name = "grpcio" +version = "1.71.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1c/95/aa11fc09a85d91fbc7dd405dcb2a1e0256989d67bf89fa65ae24b3ba105a/grpcio-1.71.0.tar.gz", hash = "sha256:2b85f7820475ad3edec209d3d89a7909ada16caab05d3f2e08a7e8ae3200a55c", size = 12549828 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/7c/c5/ef610b3f988cc0cc67b765f72b8e2db06a1db14e65acb5ae7810a6b7042e/grpcio-1.71.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:c200cb6f2393468142eb50ab19613229dcc7829b5ccee8b658a36005f6669fdd", size = 5210643 }, + { url = "https://files.pythonhosted.org/packages/bf/de/c84293c961622df302c0d5d07ec6e2d4cd3874ea42f602be2df09c4ad44f/grpcio-1.71.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:b2266862c5ad664a380fbbcdbdb8289d71464c42a8c29053820ee78ba0119e5d", size = 11308962 }, + { url = "https://files.pythonhosted.org/packages/7c/38/04c9e0dc8c904570c80faa1f1349b190b63e45d6b2782ec8567b050efa9d/grpcio-1.71.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:0ab8b2864396663a5b0b0d6d79495657ae85fa37dcb6498a2669d067c65c11ea", size = 5699236 }, + { url = "https://files.pythonhosted.org/packages/95/96/e7be331d1298fa605ea7c9ceafc931490edd3d5b33c4f695f1a0667f3491/grpcio-1.71.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c30f393f9d5ff00a71bb56de4aa75b8fe91b161aeb61d39528db6b768d7eac69", size = 6339767 }, + { url = "https://files.pythonhosted.org/packages/5d/b7/7e7b7bb6bb18baf156fd4f2f5b254150dcdd6cbf0def1ee427a2fb2bfc4d/grpcio-1.71.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f250ff44843d9a0615e350c77f890082102a0318d66a99540f54769c8766ab73", size = 5943028 }, + { url = "https://files.pythonhosted.org/packages/13/aa/5fb756175995aeb47238d706530772d9a7ac8e73bcca1b47dc145d02c95f/grpcio-1.71.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e6d8de076528f7c43a2f576bc311799f89d795aa6c9b637377cc2b1616473804", size = 6031841 }, + { url = "https://files.pythonhosted.org/packages/54/93/172783e01eed61f7f180617b7fa4470f504e383e32af2587f664576a7101/grpcio-1.71.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:9b91879d6da1605811ebc60d21ab6a7e4bae6c35f6b63a061d61eb818c8168f6", size = 6651039 }, + { url = "https://files.pythonhosted.org/packages/6f/99/62654b220a27ed46d3313252214f4bc66261143dc9b58004085cd0646753/grpcio-1.71.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f71574afdf944e6652203cd1badcda195b2a27d9c83e6d88dc1ce3cfb73b31a5", size = 6198465 }, + { url = "https://files.pythonhosted.org/packages/68/35/96116de833b330abe4412cc94edc68f99ed2fa3e39d8713ff307b3799e81/grpcio-1.71.0-cp310-cp310-win32.whl", hash = "sha256:8997d6785e93308f277884ee6899ba63baafa0dfb4729748200fcc537858a509", size = 3620382 }, + { url = "https://files.pythonhosted.org/packages/b7/09/f32ef637e386f3f2c02effac49699229fa560ce9007682d24e9e212d2eb4/grpcio-1.71.0-cp310-cp310-win_amd64.whl", hash = "sha256:7d6ac9481d9d0d129224f6d5934d5832c4b1cddb96b59e7eba8416868909786a", size = 4280302 }, + { url = "https://files.pythonhosted.org/packages/63/04/a085f3ad4133426f6da8c1becf0749872a49feb625a407a2e864ded3fb12/grpcio-1.71.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:d6aa986318c36508dc1d5001a3ff169a15b99b9f96ef5e98e13522c506b37eef", size = 5210453 }, + { url = "https://files.pythonhosted.org/packages/b4/d5/0bc53ed33ba458de95020970e2c22aa8027b26cc84f98bea7fcad5d695d1/grpcio-1.71.0-cp311-cp311-macosx_10_14_universal2.whl", hash = "sha256:d2c170247315f2d7e5798a22358e982ad6eeb68fa20cf7a820bb74c11f0736e7", size = 11347567 }, + { url = "https://files.pythonhosted.org/packages/e3/6d/ce334f7e7a58572335ccd61154d808fe681a4c5e951f8a1ff68f5a6e47ce/grpcio-1.71.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:e6f83a583ed0a5b08c5bc7a3fe860bb3c2eac1f03f1f63e0bc2091325605d2b7", size = 5696067 }, + { url = "https://files.pythonhosted.org/packages/05/4a/80befd0b8b1dc2b9ac5337e57473354d81be938f87132e147c4a24a581bd/grpcio-1.71.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4be74ddeeb92cc87190e0e376dbc8fc7736dbb6d3d454f2fa1f5be1dee26b9d7", size = 6348377 }, + { url = "https://files.pythonhosted.org/packages/c7/67/cbd63c485051eb78663355d9efd1b896cfb50d4a220581ec2cb9a15cd750/grpcio-1.71.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4dd0dfbe4d5eb1fcfec9490ca13f82b089a309dc3678e2edabc144051270a66e", size = 5940407 }, + { url = "https://files.pythonhosted.org/packages/98/4b/7a11aa4326d7faa499f764eaf8a9b5a0eb054ce0988ee7ca34897c2b02ae/grpcio-1.71.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:a2242d6950dc892afdf9e951ed7ff89473aaf744b7d5727ad56bdaace363722b", size = 6030915 }, + { url = "https://files.pythonhosted.org/packages/eb/a2/cdae2d0e458b475213a011078b0090f7a1d87f9a68c678b76f6af7c6ac8c/grpcio-1.71.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:0fa05ee31a20456b13ae49ad2e5d585265f71dd19fbd9ef983c28f926d45d0a7", size = 6648324 }, + { url = "https://files.pythonhosted.org/packages/27/df/f345c8daaa8d8574ce9869f9b36ca220c8845923eb3087e8f317eabfc2a8/grpcio-1.71.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:3d081e859fb1ebe176de33fc3adb26c7d46b8812f906042705346b314bde32c3", size = 6197839 }, + { url = "https://files.pythonhosted.org/packages/f2/2c/cd488dc52a1d0ae1bad88b0d203bc302efbb88b82691039a6d85241c5781/grpcio-1.71.0-cp311-cp311-win32.whl", hash = "sha256:d6de81c9c00c8a23047136b11794b3584cdc1460ed7cbc10eada50614baa1444", size = 3619978 }, + { url = "https://files.pythonhosted.org/packages/ee/3f/cf92e7e62ccb8dbdf977499547dfc27133124d6467d3a7d23775bcecb0f9/grpcio-1.71.0-cp311-cp311-win_amd64.whl", hash = "sha256:24e867651fc67717b6f896d5f0cac0ec863a8b5fb7d6441c2ab428f52c651c6b", size = 4282279 }, + { url = "https://files.pythonhosted.org/packages/4c/83/bd4b6a9ba07825bd19c711d8b25874cd5de72c2a3fbf635c3c344ae65bd2/grpcio-1.71.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:0ff35c8d807c1c7531d3002be03221ff9ae15712b53ab46e2a0b4bb271f38537", size = 5184101 }, + { url = "https://files.pythonhosted.org/packages/31/ea/2e0d90c0853568bf714693447f5c73272ea95ee8dad107807fde740e595d/grpcio-1.71.0-cp312-cp312-macosx_10_14_universal2.whl", hash = "sha256:b78a99cd1ece4be92ab7c07765a0b038194ded2e0a26fd654591ee136088d8d7", size = 11310927 }, + { url = "https://files.pythonhosted.org/packages/ac/bc/07a3fd8af80467390af491d7dc66882db43884128cdb3cc8524915e0023c/grpcio-1.71.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:dc1a1231ed23caac1de9f943d031f1bc38d0f69d2a3b243ea0d664fc1fbd7fec", size = 5654280 }, + { url = "https://files.pythonhosted.org/packages/16/af/21f22ea3eed3d0538b6ef7889fce1878a8ba4164497f9e07385733391e2b/grpcio-1.71.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e6beeea5566092c5e3c4896c6d1d307fb46b1d4bdf3e70c8340b190a69198594", size = 6312051 }, + { url = "https://files.pythonhosted.org/packages/49/9d/e12ddc726dc8bd1aa6cba67c85ce42a12ba5b9dd75d5042214a59ccf28ce/grpcio-1.71.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5170929109450a2c031cfe87d6716f2fae39695ad5335d9106ae88cc32dc84c", size = 5910666 }, + { url = "https://files.pythonhosted.org/packages/d9/e9/38713d6d67aedef738b815763c25f092e0454dc58e77b1d2a51c9d5b3325/grpcio-1.71.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:5b08d03ace7aca7b2fadd4baf291139b4a5f058805a8327bfe9aece7253b6d67", size = 6012019 }, + { url = "https://files.pythonhosted.org/packages/80/da/4813cd7adbae6467724fa46c952d7aeac5e82e550b1c62ed2aeb78d444ae/grpcio-1.71.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:f903017db76bf9cc2b2d8bdd37bf04b505bbccad6be8a81e1542206875d0e9db", size = 6637043 }, + { url = "https://files.pythonhosted.org/packages/52/ca/c0d767082e39dccb7985c73ab4cf1d23ce8613387149e9978c70c3bf3b07/grpcio-1.71.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:469f42a0b410883185eab4689060a20488a1a0a00f8bbb3cbc1061197b4c5a79", size = 6186143 }, + { url = "https://files.pythonhosted.org/packages/00/61/7b2c8ec13303f8fe36832c13d91ad4d4ba57204b1c723ada709c346b2271/grpcio-1.71.0-cp312-cp312-win32.whl", hash = "sha256:ad9f30838550695b5eb302add33f21f7301b882937460dd24f24b3cc5a95067a", size = 3604083 }, + { url = "https://files.pythonhosted.org/packages/fd/7c/1e429c5fb26122055d10ff9a1d754790fb067d83c633ff69eddcf8e3614b/grpcio-1.71.0-cp312-cp312-win_amd64.whl", hash = "sha256:652350609332de6dac4ece254e5d7e1ff834e203d6afb769601f286886f6f3a8", size = 4272191 }, + { url = "https://files.pythonhosted.org/packages/04/dd/b00cbb45400d06b26126dcfdbdb34bb6c4f28c3ebbd7aea8228679103ef6/grpcio-1.71.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:cebc1b34ba40a312ab480ccdb396ff3c529377a2fce72c45a741f7215bfe8379", size = 5184138 }, + { url = "https://files.pythonhosted.org/packages/ed/0a/4651215983d590ef53aac40ba0e29dda941a02b097892c44fa3357e706e5/grpcio-1.71.0-cp313-cp313-macosx_10_14_universal2.whl", hash = "sha256:85da336e3649a3d2171e82f696b5cad2c6231fdd5bad52616476235681bee5b3", size = 11310747 }, + { url = "https://files.pythonhosted.org/packages/57/a3/149615b247f321e13f60aa512d3509d4215173bdb982c9098d78484de216/grpcio-1.71.0-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:f9a412f55bb6e8f3bb000e020dbc1e709627dcb3a56f6431fa7076b4c1aab0db", size = 5653991 }, + { url = "https://files.pythonhosted.org/packages/ca/56/29432a3e8d951b5e4e520a40cd93bebaa824a14033ea8e65b0ece1da6167/grpcio-1.71.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:47be9584729534660416f6d2a3108aaeac1122f6b5bdbf9fd823e11fe6fbaa29", size = 6312781 }, + { url = "https://files.pythonhosted.org/packages/a3/f8/286e81a62964ceb6ac10b10925261d4871a762d2a763fbf354115f9afc98/grpcio-1.71.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7c9c80ac6091c916db81131d50926a93ab162a7e97e4428ffc186b6e80d6dda4", size = 5910479 }, + { url = "https://files.pythonhosted.org/packages/35/67/d1febb49ec0f599b9e6d4d0d44c2d4afdbed9c3e80deb7587ec788fcf252/grpcio-1.71.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:789d5e2a3a15419374b7b45cd680b1e83bbc1e52b9086e49308e2c0b5bbae6e3", size = 6013262 }, + { url = "https://files.pythonhosted.org/packages/a1/04/f9ceda11755f0104a075ad7163fc0d96e2e3a9fe25ef38adfc74c5790daf/grpcio-1.71.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:1be857615e26a86d7363e8a163fade914595c81fec962b3d514a4b1e8760467b", size = 6643356 }, + { url = "https://files.pythonhosted.org/packages/fb/ce/236dbc3dc77cf9a9242adcf1f62538734ad64727fabf39e1346ad4bd5c75/grpcio-1.71.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:a76d39b5fafd79ed604c4be0a869ec3581a172a707e2a8d7a4858cb05a5a7637", size = 6186564 }, + { url = "https://files.pythonhosted.org/packages/10/fd/b3348fce9dd4280e221f513dd54024e765b21c348bc475516672da4218e9/grpcio-1.71.0-cp313-cp313-win32.whl", hash = "sha256:74258dce215cb1995083daa17b379a1a5a87d275387b7ffe137f1d5131e2cfbb", size = 3601890 }, + { url = "https://files.pythonhosted.org/packages/be/f8/db5d5f3fc7e296166286c2a397836b8b042f7ad1e11028d82b061701f0f7/grpcio-1.71.0-cp313-cp313-win_amd64.whl", hash = "sha256:22c3bc8d488c039a199f7a003a38cb7635db6656fa96437a8accde8322ce2366", size = 4273308 }, +] + +[[package]] +name = "grpcio-tools" +version = "1.71.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "grpcio" }, + { name = "protobuf" }, + { name = "setuptools" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/05/d2/c0866a48c355a6a4daa1f7e27e210c7fa561b1f3b7c0bce2671e89cfa31e/grpcio_tools-1.71.0.tar.gz", hash = "sha256:38dba8e0d5e0fb23a034e09644fdc6ed862be2371887eee54901999e8f6792a8", size = 5326008 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/f9/60/aa7f261eda558d018457e5c8bd8a8079136e5107a0942fd3167477ab50e2/grpcio_tools-1.71.0-cp310-cp310-linux_armv7l.whl", hash = "sha256:f4ad7f0d756546902597053d70b3af2606fbd70d7972876cd75c1e241d22ae00", size = 2385558 }, + { url = "https://files.pythonhosted.org/packages/0d/e3/e47b96e93e51398ba3462e027d93a10c0c23fffc31733de9bd4f44a2b867/grpcio_tools-1.71.0-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:64bdb291df61cf570b5256777ad5fe2b1db6d67bc46e55dc56a0a862722ae329", size = 5930039 }, + { url = "https://files.pythonhosted.org/packages/a6/69/5d8920002483b2a65ae3b03329dfe3b668c3592f001d5358e1538f540012/grpcio_tools-1.71.0-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:8dd9795e982d77a4b496f7278b943c2563d9afde2069cdee78c111a40cc4d675", size = 2351932 }, + { url = "https://files.pythonhosted.org/packages/c4/50/8116e307662a2337cdc3f0e1a8b23af197129448b7ff7e0cf1a76c9b0178/grpcio_tools-1.71.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c1b5860c41a36b26fec4f52998f1a451d0525a5c9a4fb06b6ea3e9211abdb925", size = 2744962 }, + { url = "https://files.pythonhosted.org/packages/e3/4b/d95be4aaf78d7b02dff3bd332c75c228288178e92af0e5228759ac5002a0/grpcio_tools-1.71.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3059c14035e5dc03d462f261e5900b9a077fd1a36976c3865b8507474520bad4", size = 2476716 }, + { url = "https://files.pythonhosted.org/packages/37/c2/c784a3705b1a1fd277751a8fc881d5a29325a460b9211e3c6164f594b178/grpcio_tools-1.71.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:f360981b215b1d5aff9235b37e7e1826246e35bbac32a53e41d4e990a37b8f4c", size = 2854132 }, + { url = "https://files.pythonhosted.org/packages/93/8f/173adbf72ed3996e1962182b55abf30151edc8b53daac0bf15cc3dc4b09e/grpcio_tools-1.71.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:bfe3888c3bbe16a5aa39409bc38744a31c0c3d2daa2b0095978c56e106c85b42", size = 3305069 }, + { url = "https://files.pythonhosted.org/packages/e4/a8/b1e7df63e7f83336275922f92ded1cd6918964c511280b31c872c54538f4/grpcio_tools-1.71.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:145985c0bf12131f0a1503e65763e0f060473f7f3928ed1ff3fb0e8aad5bc8ac", size = 2916636 }, + { url = "https://files.pythonhosted.org/packages/be/a3/53f1e74c6e1c92ad94d7a0127a60fe913276a3e8c864737a053a1574b05c/grpcio_tools-1.71.0-cp310-cp310-win32.whl", hash = "sha256:82c430edd939bb863550ee0fecf067d78feff828908a1b529bbe33cc57f2419c", size = 949576 }, + { url = "https://files.pythonhosted.org/packages/97/43/4a3ae830c1405bcb1ba47f2225779dbe9fc009ba341d4a90012919304855/grpcio_tools-1.71.0-cp310-cp310-win_amd64.whl", hash = "sha256:83e90724e3f02415c628e4ead1d6ffe063820aaaa078d9a39176793df958cd5a", size = 1121087 }, + { url = "https://files.pythonhosted.org/packages/5d/ec/73b9797ffec80e1faf039ce3e2f0513e26e1a68eedc525ed294ae2a44d03/grpcio_tools-1.71.0-cp311-cp311-linux_armv7l.whl", hash = "sha256:1f19b16b49afa5d21473f49c0966dd430c88d089cd52ac02404d8cef67134efb", size = 2385557 }, + { url = "https://files.pythonhosted.org/packages/bf/87/42c6e192b7b09c9610a53e771797f7826aee4f6e769683985ae406a2d862/grpcio_tools-1.71.0-cp311-cp311-macosx_10_14_universal2.whl", hash = "sha256:459c8f5e00e390aecd5b89de67deb3ec7188a274bc6cb50e43cef35ab3a3f45d", size = 5954404 }, + { url = "https://files.pythonhosted.org/packages/25/30/3fd385a56d32dce34cde09a64dbaf7cf85d395f2bcd86dd41e4b4ee5938f/grpcio_tools-1.71.0-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:edab7e6518de01196be37f96cb1e138c3819986bf5e2a6c9e1519b4d716b2f5a", size = 2352061 }, + { url = "https://files.pythonhosted.org/packages/87/eb/e9971c7693a2d85e7f55760f7906211a95ff74af4d41b05d187849d7fb58/grpcio_tools-1.71.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8b93b9f6adc7491d4c10144c0643409db298e5e63c997106a804f6f0248dbaf4", size = 2745033 }, + { url = "https://files.pythonhosted.org/packages/15/72/4e69beae87a1b334f80da9e93c8e2f5c8fe4860c956a781246a092dc4c97/grpcio_tools-1.71.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ae5f2efa9e644c10bf1021600bfc099dfbd8e02b184d2d25dc31fcd6c2bc59e", size = 2476743 }, + { url = "https://files.pythonhosted.org/packages/b5/f3/336d2c83f1bfc00a5376bf20dd2273d7aa891b03dd91b11c71ca47392351/grpcio_tools-1.71.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:65aa082f4435571d65d5ce07fc444f23c3eff4f3e34abef599ef8c9e1f6f360f", size = 2853693 }, + { url = "https://files.pythonhosted.org/packages/62/ba/cc7ace518c11501a4b8620df5edb8188e81470e5b82dc6829212f3e9b2ff/grpcio_tools-1.71.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:1331e726e08b7bdcbf2075fcf4b47dff07842b04845e6e220a08a4663e232d7f", size = 3304474 }, + { url = "https://files.pythonhosted.org/packages/00/0d/4b843654af3d5aa2f1a5775df1d583e6e3471e6d569106fd3213ad185a98/grpcio_tools-1.71.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:6693a7d3ba138b0e693b3d1f687cdd9db9e68976c3fa2b951c17a072fea8b583", size = 2916147 }, + { url = "https://files.pythonhosted.org/packages/e4/14/047e1c817422bc3d434247b9c640c51fd51ca4e047583ff31d927c3dea73/grpcio_tools-1.71.0-cp311-cp311-win32.whl", hash = "sha256:6d11ed3ff7b6023b5c72a8654975324bb98c1092426ba5b481af406ff559df00", size = 949374 }, + { url = "https://files.pythonhosted.org/packages/86/cb/739a1b6d517672693796022c0f9061f63eaa243ec70cbbfa59bf881ed9fb/grpcio_tools-1.71.0-cp311-cp311-win_amd64.whl", hash = "sha256:072b2a5805ac97e4623b3aa8f7818275f3fb087f4aa131b0fce00471065f6eaa", size = 1120786 }, + { url = "https://files.pythonhosted.org/packages/de/e4/156956b92ad0298290c3d68e6670bc5a6fbefcccfe1ec3997480605e7135/grpcio_tools-1.71.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:61c0409d5bdac57a7bd0ce0ab01c1c916728fe4c8a03d77a25135ad481eb505c", size = 2385480 }, + { url = "https://files.pythonhosted.org/packages/c1/08/9930eb4bb38c5214041c9f24f8b35e9864a7938282db986836546c782d52/grpcio_tools-1.71.0-cp312-cp312-macosx_10_14_universal2.whl", hash = "sha256:28784f39921d061d2164a9dcda5164a69d07bf29f91f0ea50b505958292312c9", size = 5951891 }, + { url = "https://files.pythonhosted.org/packages/73/65/931f29ec9c33719d48e1e30446ecce6f5d2cd4e4934fa73fbe07de41c43b/grpcio_tools-1.71.0-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:192808cf553cedca73f0479cc61d5684ad61f24db7a5f3c4dfe1500342425866", size = 2351967 }, + { url = "https://files.pythonhosted.org/packages/b8/26/2ec8748534406214f20a4809c36efcfa88d1a26246e8312102e3ef8c295d/grpcio_tools-1.71.0-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:989ee9da61098230d3d4c8f8f8e27c2de796f1ff21b1c90110e636d9acd9432b", size = 2745003 }, + { url = "https://files.pythonhosted.org/packages/f1/33/87b4610c86a4e10ee446b543a4d536f94ab04f828bab841f0bc1a083de72/grpcio_tools-1.71.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:541a756276c8a55dec991f6c0106ae20c8c8f5ce8d0bdbfcb01e2338d1a8192b", size = 2476455 }, + { url = "https://files.pythonhosted.org/packages/00/7c/f7f0cc36a43be9d45b3ce2a55245f3c7d063a24b7930dd719929e58871a4/grpcio_tools-1.71.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:870c0097700d13c403e5517cb7750ab5b4a791ce3e71791c411a38c5468b64bd", size = 2854333 }, + { url = "https://files.pythonhosted.org/packages/07/c4/34b9ea62b173c13fa7accba5f219355b320c05c80c79c3ba70fe52f47b2f/grpcio_tools-1.71.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:abd57f615e88bf93c3c6fd31f923106e3beb12f8cd2df95b0d256fa07a7a0a57", size = 3304297 }, + { url = "https://files.pythonhosted.org/packages/5c/ef/9d3449db8a07688dc3de7dcbd2a07048a128610b1a491c5c0cb3e90a00c5/grpcio_tools-1.71.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:753270e2d06d37e6d7af8967d1d059ec635ad215882041a36294f4e2fd502b2e", size = 2916212 }, + { url = "https://files.pythonhosted.org/packages/2e/c6/990e8194c934dfe7cf89ef307c319fa4f2bc0b78aeca707addbfa1e502f1/grpcio_tools-1.71.0-cp312-cp312-win32.whl", hash = "sha256:0e647794bd7138b8c215e86277a9711a95cf6a03ff6f9e555d54fdf7378b9f9d", size = 948849 }, + { url = "https://files.pythonhosted.org/packages/42/95/3c36d3205e6bd19853cc2420e44b6ef302eb4cfcf56498973c7e85f6c03b/grpcio_tools-1.71.0-cp312-cp312-win_amd64.whl", hash = "sha256:48debc879570972d28bfe98e4970eff25bb26da3f383e0e49829b2d2cd35ad87", size = 1120294 }, + { url = "https://files.pythonhosted.org/packages/84/a7/70dc7e9957bcbaccd4dcb6cc11215e0b918f546d55599221522fe0d073e0/grpcio_tools-1.71.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:9a78d07d6c301a25ef5ede962920a522556a1dfee1ccc05795994ceb867f766c", size = 2384758 }, + { url = "https://files.pythonhosted.org/packages/65/79/57320b28d0a0c5ec94095fd571a65292f8ed7e1c47e59ae4021e8a48d49b/grpcio_tools-1.71.0-cp313-cp313-macosx_10_14_universal2.whl", hash = "sha256:580ac88141c9815557e63c9c04f5b1cdb19b4db8d0cb792b573354bde1ee8b12", size = 5951661 }, + { url = "https://files.pythonhosted.org/packages/80/3d/343df5ed7c5dd66fc7a19e4ef3e97ccc4f5d802122b04cd6492f0dcd79f5/grpcio_tools-1.71.0-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:f7c678e68ece0ae908ecae1c4314a0c2c7f83e26e281738b9609860cc2c82d96", size = 2351571 }, + { url = "https://files.pythonhosted.org/packages/56/2f/b9736e8c84e880c4237f5b880c6c799b4977c5cde190999bc7ab4b2ec445/grpcio_tools-1.71.0-cp313-cp313-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:56ecd6cc89b5e5eed1de5eb9cafce86c9c9043ee3840888cc464d16200290b53", size = 2744580 }, + { url = "https://files.pythonhosted.org/packages/76/9b/bdb384967353da7bf64bac4232f4cf8ae43f19d0f2f640978d4d4197e667/grpcio_tools-1.71.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e52a041afc20ab2431d756b6295d727bd7adee813b21b06a3483f4a7a15ea15f", size = 2475978 }, + { url = "https://files.pythonhosted.org/packages/26/71/1411487fd7862d347b98fda5e3beef611a71b2ac2faac62a965d9e2536b3/grpcio_tools-1.71.0-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:2a1712f12102b60c8d92779b89d0504e0d6f3a59f2b933e5622b8583f5c02992", size = 2853314 }, + { url = "https://files.pythonhosted.org/packages/03/06/59d0523eb1ba2f64edc72cb150152fa1b2e77061cae3ef3ecd3ef2a87f51/grpcio_tools-1.71.0-cp313-cp313-musllinux_1_1_i686.whl", hash = "sha256:41878cb7a75477e62fdd45e7e9155b3af1b7a5332844021e2511deaf99ac9e6c", size = 3303981 }, + { url = "https://files.pythonhosted.org/packages/c2/71/fb9fb49f2b738ec1dfbbc8cdce0b26e5f9c5fc0edef72e453580620d6a36/grpcio_tools-1.71.0-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:682e958b476049ccc14c71bedf3f979bced01f6e0c04852efc5887841a32ad6b", size = 2915876 }, + { url = "https://files.pythonhosted.org/packages/bd/0f/0d49f6fe6fa2d09e9820dd9eeb30437e86002303076be2b6ada0fb52b8f2/grpcio_tools-1.71.0-cp313-cp313-win32.whl", hash = "sha256:0ccfb837152b7b858b9f26bb110b3ae8c46675d56130f6c2f03605c4f129be13", size = 948245 }, + { url = "https://files.pythonhosted.org/packages/bb/14/ab131a39187bfea950280b2277a82d2033469fe8c86f73b10b19f53cc5ca/grpcio_tools-1.71.0-cp313-cp313-win_amd64.whl", hash = "sha256:ffff9bc5eacb34dd26b487194f7d44a3e64e752fc2cf049d798021bf25053b87", size = 1119649 }, +] + [[package]] name = "h11" version = "0.14.0" @@ -802,6 +905,28 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/95/04/ff642e65ad6b90db43e668d70ffb6736436c7ce41fcc549f4e9472234127/h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761", size = 58259 }, ] +[[package]] +name = "h2" +version = "4.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "hpack" }, + { name = "hyperframe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1b/38/d7f80fd13e6582fb8e0df8c9a653dcc02b03ca34f4d72f34869298c5baf8/h2-4.2.0.tar.gz", hash = "sha256:c8a52129695e88b1a0578d8d2cc6842bbd79128ac685463b887ee278126ad01f", size = 2150682 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d0/9e/984486f2d0a0bd2b024bf4bc1c62688fcafa9e61991f041fb0e2def4a982/h2-4.2.0-py3-none-any.whl", hash = "sha256:479a53ad425bb29af087f3458a61d30780bc818e4ebcf01f0b536ba916462ed0", size = 60957 }, +] + +[[package]] +name = "hpack" +version = "4.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/48/71de9ed269fdae9c8057e5a4c0aa7402e8bb16f2c6e90b3aa53327b113f8/hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca", size = 51276 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/c6/80c95b1b2b94682a72cbdbfb85b81ae2daffa4291fbfa1b1464502ede10d/hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496", size = 34357 }, +] + [[package]] name = "httpcore" version = "1.0.7" @@ -830,6 +955,11 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/39/e50c7c3a983047577ee07d2a9e53faf5a69493943ec3f6a384bdc792deb2/httpx-0.28.1-py3-none-any.whl", hash = "sha256:d909fcccc110f8c7faf814ca82a9a4d816bc5a6dbfea25d6591d6985b8ba59ad", size = 73517 }, ] +[package.optional-dependencies] +http2 = [ + { name = "h2" }, +] + [[package]] name = "httpx-sse" version = "0.4.0" @@ -857,6 +987,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/4d/8092df2cb0cafa9fcaf691db851b2fccfe9cad4048e081436bbbdf56e4e1/huggingface_hub-0.29.0-py3-none-any.whl", hash = "sha256:c02daa0b6bafbdacb1320fdfd1dc7151d0940825c88c4ef89837fdb1f6ea0afe", size = 468012 }, ] +[[package]] +name = "hyperframe" +version = "6.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/02/e7/94f8232d4a74cc99514c13a9f995811485a6903d48e5d952771ef6322e30/hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08", size = 26566 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/30/47d0bf6072f7252e6521f3447ccfa40b421b6824517f82854703d0f5a98b/hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5", size = 13007 }, +] + [[package]] name = "identify" version = "2.6.7" @@ -1250,6 +1389,7 @@ unit = [ { name = "chardet" }, { name = "openai" }, { name = "pypdf" }, + { name = "qdrant-client" }, { name = "sqlite-vec" }, ] @@ -1290,6 +1430,7 @@ requires-dist = [ { name = "pytest-cov", marker = "extra == 'dev'" }, { name = "pytest-html", marker = "extra == 'dev'" }, { name = "python-dotenv" }, + { name = "qdrant-client", marker = "extra == 'unit'" }, { name = "requests" }, { name = "rich" }, { name = "rich", marker = "extra == 'codegen'" }, @@ -1314,7 +1455,6 @@ requires-dist = [ { name = "types-setuptools", marker = "extra == 'dev'" }, { name = "uvicorn", marker = "extra == 'dev'" }, ] -provides-extras = ["dev", "unit", "test", "docs", "codegen"] [[package]] name = "llama-stack-client" @@ -2062,6 +2202,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/88/5f/e351af9a41f866ac3f1fac4ca0613908d9a41741cfcf2228f4ad853b697d/pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669", size = 20556 }, ] +[[package]] +name = "portalocker" +version = "2.10.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pywin32", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/fb/a70a4214956182e0d7a9099ab17d50bfcba1056188e9b14f35b9e2b62a0d/portalocker-2.10.1-py3-none-any.whl", hash = "sha256:53a5984ebc86a025552264b459b46a2086e269b21823cb572f8f28ee759e45bf", size = 18423 }, +] + [[package]] name = "pre-commit" version = "4.1.0" @@ -2668,6 +2820,24 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e3/fe/72e7e166bda3885810bee7b23049133e142f7c80c295bae02c562caeea16/pyzmq-26.2.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:bd8fdee945b877aa3bffc6a5a8816deb048dab0544f9df3731ecd0e54d8c84c9", size = 556563 }, ] +[[package]] +name = "qdrant-client" +version = "1.13.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "grpcio" }, + { name = "grpcio-tools" }, + { name = "httpx", extra = ["http2"] }, + { name = "numpy" }, + { name = "portalocker" }, + { name = "pydantic" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/eb/58/1e4acd7ff7637ed56a66e5044699e7af6067232703d0b34f05068fc6234b/qdrant_client-1.13.3.tar.gz", hash = "sha256:61ca09e07c6d7ac0dfbdeb13dca4fe5f3e08fa430cb0d74d66ef5d023a70adfc", size = 266278 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dd/b4/bd676f91f5234ab59282e4a110f324029684482cbe08e7a1c77b6338013b/qdrant_client-1.13.3-py3-none-any.whl", hash = "sha256:f52cacbb936e547d3fceb1aaed3e3c56be0ebfd48e8ea495ea3dbc89c671d1d2", size = 306674 }, +] + [[package]] name = "rapidfuzz" version = "3.12.2" @@ -3417,7 +3587,8 @@ source = { registry = "https://download.pytorch.org/whl/cpu" } resolution-markers = [ "python_full_version < '3.11' and sys_platform == 'darwin'", "python_full_version == '3.11.*' and sys_platform == 'darwin'", - "python_full_version >= '3.12' and sys_platform == 'darwin'", + "python_full_version >= '3.13' and sys_platform == 'darwin'", + "python_full_version == '3.12.*' and sys_platform == 'darwin'", ] dependencies = [ { name = "filelock", marker = "sys_platform == 'darwin'" }, @@ -3444,8 +3615,10 @@ resolution-markers = [ "python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'", "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')", "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux')", - "python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')", + "(python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'linux')", + "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", ] dependencies = [ { name = "filelock", marker = "sys_platform != 'darwin'" }, @@ -3482,8 +3655,10 @@ resolution-markers = [ "python_full_version < '3.11' and sys_platform == 'darwin'", "python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", "python_full_version == '3.11.*' and sys_platform == 'darwin'", - "python_full_version >= '3.12' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "python_full_version >= '3.12' and sys_platform == 'darwin'", + "python_full_version >= '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version == '3.12.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version >= '3.13' and sys_platform == 'darwin'", + "python_full_version == '3.12.*' and sys_platform == 'darwin'", ] dependencies = [ { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, @@ -3509,7 +3684,8 @@ source = { registry = "https://download.pytorch.org/whl/cpu" } resolution-markers = [ "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')", "(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')", - "(python_full_version >= '3.12' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and sys_platform != 'darwin' and sys_platform != 'linux')", + "(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')", + "(python_full_version == '3.12.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.12.*' and sys_platform != 'darwin' and sys_platform != 'linux')", ] dependencies = [ { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, From c029fbcd13ff270888f3e34e5369fb9d750821d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 18 Mar 2025 22:06:53 +0100 Subject: [PATCH 05/19] fix: return 4xx for non-existent resources in GET requests (#1635) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? - Removed Optional return types for GET methods - Raised ValueError when requested resource is not found - Ensures proper 4xx response for missing resources - Updated the API generator to check for wrong signatures ``` $ uv run --with ".[dev]" ./docs/openapi_generator/run_openapi_generator.sh Validating API method return types... API Method Return Type Validation Errors: Method ScoringFunctions.get_scoring_function returns Optional type ``` Closes: https://github.com/meta-llama/llama-stack/issues/1630 ## Test Plan Run the server then: ``` curl http://127.0.0.1:8321/v1/models/foo {"detail":"Invalid value: Model 'foo' not found"}% ``` Server log: ``` INFO: 127.0.0.1:52307 - "GET /v1/models/foo HTTP/1.1" 400 Bad Request 09:51:42.654 [END] /v1/models/foo [StatusCode.OK] (134.65ms) 09:51:42.651 [ERROR] Error executing endpoint route='/v1/models/{model_id:path}' method='get' Traceback (most recent call last): File "/Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py", line 193, in endpoint return await maybe_await(value) File "/Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/server/server.py", line 156, in maybe_await return await value File "/Users/leseb/Documents/AI/llama-stack/llama_stack/providers/utils/telemetry/trace_protocol.py", line 102, in async_wrapper result = await method(self, *args, **kwargs) File "/Users/leseb/Documents/AI/llama-stack/llama_stack/distribution/routers/routing_tables.py", line 217, in get_model raise ValueError(f"Model '{model_id}' not found") ValueError: Model 'foo' not found ``` Signed-off-by: Sébastien Han --- docs/_static/llama-stack-spec.html | 90 +++---------------- docs/_static/llama-stack-spec.yaml | 40 +++------ docs/openapi_generator/generate.py | 12 ++- docs/openapi_generator/pyopenapi/utility.py | 39 +++++++- llama_stack/apis/benchmarks/benchmarks.py | 2 +- llama_stack/apis/datasets/datasets.py | 2 +- llama_stack/apis/eval/eval.py | 2 +- llama_stack/apis/files/files.py | 2 +- llama_stack/apis/models/models.py | 2 +- .../apis/post_training/post_training.py | 4 +- .../scoring_functions/scoring_functions.py | 2 +- llama_stack/apis/shields/shields.py | 2 +- llama_stack/apis/vector_dbs/vector_dbs.py | 2 +- .../distribution/routers/routing_tables.py | 47 +++++++--- 14 files changed, 112 insertions(+), 136 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 210a84b03..72b2e6b17 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -1101,14 +1101,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/Benchmark" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/Benchmark" } } } @@ -1150,14 +1143,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/Dataset" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/Dataset" } } } @@ -1232,14 +1218,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/Model" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/Model" } } } @@ -1314,14 +1293,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/ScoringFn" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/ScoringFn" } } } @@ -1363,14 +1335,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/Shield" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/Shield" } } } @@ -1673,14 +1638,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/PostTrainingJobArtifactsResponse" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/PostTrainingJobArtifactsResponse" } } } @@ -1722,14 +1680,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/PostTrainingJobStatusResponse" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/PostTrainingJobStatusResponse" } } } @@ -1804,14 +1755,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/FileUploadResponse" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/FileUploadResponse" } } } @@ -1913,14 +1857,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/VectorDB" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/VectorDB" } } } @@ -2246,14 +2183,7 @@ "content": { "application/json": { "schema": { - "oneOf": [ - { - "$ref": "#/components/schemas/JobStatus" - }, - { - "type": "null" - } - ] + "$ref": "#/components/schemas/JobStatus" } } } diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index a1eb07444..6f4a9528b 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -757,9 +757,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/Benchmark' - - type: 'null' + $ref: '#/components/schemas/Benchmark' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -787,9 +785,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/Dataset' - - type: 'null' + $ref: '#/components/schemas/Dataset' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -840,9 +836,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/Model' - - type: 'null' + $ref: '#/components/schemas/Model' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -893,9 +887,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/ScoringFn' - - type: 'null' + $ref: '#/components/schemas/ScoringFn' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -923,9 +915,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/Shield' - - type: 'null' + $ref: '#/components/schemas/Shield' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -1127,9 +1117,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/PostTrainingJobArtifactsResponse' - - type: 'null' + $ref: '#/components/schemas/PostTrainingJobArtifactsResponse' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -1157,9 +1145,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/PostTrainingJobStatusResponse' - - type: 'null' + $ref: '#/components/schemas/PostTrainingJobStatusResponse' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -1210,9 +1196,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/FileUploadResponse' - - type: 'null' + $ref: '#/components/schemas/FileUploadResponse' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -1281,9 +1265,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/VectorDB' - - type: 'null' + $ref: '#/components/schemas/VectorDB' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -1509,9 +1491,7 @@ paths: content: application/json: schema: - oneOf: - - $ref: '#/components/schemas/JobStatus' - - type: 'null' + $ref: '#/components/schemas/JobStatus' '400': $ref: '#/components/responses/BadRequest400' '429': diff --git a/docs/openapi_generator/generate.py b/docs/openapi_generator/generate.py index a2553f905..879ac95e2 100644 --- a/docs/openapi_generator/generate.py +++ b/docs/openapi_generator/generate.py @@ -12,7 +12,7 @@ from datetime import datetime from pathlib import Path - +import sys import fire import ruamel.yaml as yaml @@ -21,7 +21,7 @@ from llama_stack.distribution.stack import LlamaStack # noqa: E402 from .pyopenapi.options import Options # noqa: E402 from .pyopenapi.specification import Info, Server # noqa: E402 -from .pyopenapi.utility import Specification # noqa: E402 +from .pyopenapi.utility import Specification, validate_api_method_return_types # noqa: E402 def str_presenter(dumper, data): @@ -39,6 +39,14 @@ def main(output_dir: str): if not output_dir.exists(): raise ValueError(f"Directory {output_dir} does not exist") + # Validate API protocols before generating spec + print("Validating API method return types...") + return_type_errors = validate_api_method_return_types() + if return_type_errors: + print("\nAPI Method Return Type Validation Errors:\n") + for error in return_type_errors: + print(error) + sys.exit(1) now = str(datetime.now()) print( "Converting the spec to YAML (openapi.yaml) and HTML (openapi.html) at " + now diff --git a/docs/openapi_generator/pyopenapi/utility.py b/docs/openapi_generator/pyopenapi/utility.py index f134aab4b..f60a33bb7 100644 --- a/docs/openapi_generator/pyopenapi/utility.py +++ b/docs/openapi_generator/pyopenapi/utility.py @@ -6,16 +6,19 @@ import json import typing +import inspect +import os from pathlib import Path from typing import TextIO +from typing import Any, Dict, List, Optional, Protocol, Type, Union, get_type_hints, get_origin, get_args from llama_stack.strong_typing.schema import object_to_json, StrictJsonType +from llama_stack.distribution.resolver import api_protocol_map from .generator import Generator from .options import Options from .specification import Document - THIS_DIR = Path(__file__).parent @@ -114,3 +117,37 @@ class Specification: ) f.write(html) + +def is_optional_type(type_: Any) -> bool: + """Check if a type is Optional.""" + origin = get_origin(type_) + args = get_args(type_) + return origin is Optional or (origin is Union and type(None) in args) + + +def validate_api_method_return_types() -> List[str]: + """Validate that all API methods have proper return types.""" + errors = [] + protocols = api_protocol_map() + + for protocol_name, protocol in protocols.items(): + methods = inspect.getmembers(protocol, predicate=inspect.isfunction) + + for method_name, method in methods: + if not hasattr(method, '__webmethod__'): + continue + + # Only check GET methods + if method.__webmethod__.method != "GET": + continue + + hints = get_type_hints(method) + + if 'return' not in hints: + errors.append(f"Method {protocol_name}.{method_name} has no return type annotation") + else: + return_type = hints['return'] + if is_optional_type(return_type): + errors.append(f"Method {protocol_name}.{method_name} returns Optional type") + + return errors diff --git a/llama_stack/apis/benchmarks/benchmarks.py b/llama_stack/apis/benchmarks/benchmarks.py index 39ba355e9..809af8868 100644 --- a/llama_stack/apis/benchmarks/benchmarks.py +++ b/llama_stack/apis/benchmarks/benchmarks.py @@ -52,7 +52,7 @@ class Benchmarks(Protocol): async def get_benchmark( self, benchmark_id: str, - ) -> Optional[Benchmark]: ... + ) -> Benchmark: ... @webmethod(route="/eval/benchmarks", method="POST") async def register_benchmark( diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index d033d0b70..616371c7d 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -201,7 +201,7 @@ class Datasets(Protocol): async def get_dataset( self, dataset_id: str, - ) -> Optional[Dataset]: ... + ) -> Dataset: ... @webmethod(route="/datasets", method="GET") async def list_datasets(self) -> ListDatasetsResponse: ... diff --git a/llama_stack/apis/eval/eval.py b/llama_stack/apis/eval/eval.py index dec018d83..51c38b16a 100644 --- a/llama_stack/apis/eval/eval.py +++ b/llama_stack/apis/eval/eval.py @@ -117,7 +117,7 @@ class Eval(Protocol): """ @webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET") - async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: + async def job_status(self, benchmark_id: str, job_id: str) -> JobStatus: """Get the status of a job. :param benchmark_id: The ID of the benchmark to run the evaluation on. diff --git a/llama_stack/apis/files/files.py b/llama_stack/apis/files/files.py index f17fadc8c..65c1ead6a 100644 --- a/llama_stack/apis/files/files.py +++ b/llama_stack/apis/files/files.py @@ -115,7 +115,7 @@ class Files(Protocol): async def get_upload_session_info( self, upload_id: str, - ) -> Optional[FileUploadResponse]: + ) -> FileUploadResponse: """ Returns information about an existsing upload session diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 64b9510ea..893ebc179 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -66,7 +66,7 @@ class Models(Protocol): async def get_model( self, model_id: str, - ) -> Optional[Model]: ... + ) -> Model: ... @webmethod(route="/models", method="POST") async def register_model( diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index ed15c6de4..636eb7e7b 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -202,10 +202,10 @@ class PostTraining(Protocol): async def get_training_jobs(self) -> ListPostTrainingJobsResponse: ... @webmethod(route="/post-training/job/status", method="GET") - async def get_training_job_status(self, job_uuid: str) -> Optional[PostTrainingJobStatusResponse]: ... + async def get_training_job_status(self, job_uuid: str) -> PostTrainingJobStatusResponse: ... @webmethod(route="/post-training/job/cancel", method="POST") async def cancel_training_job(self, job_uuid: str) -> None: ... @webmethod(route="/post-training/job/artifacts", method="GET") - async def get_training_job_artifacts(self, job_uuid: str) -> Optional[PostTrainingJobArtifactsResponse]: ... + async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse: ... diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 52508d2ec..b02a7a0c4 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -135,7 +135,7 @@ class ScoringFunctions(Protocol): async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ... @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET") - async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ... + async def get_scoring_function(self, scoring_fn_id: str, /) -> ScoringFn: ... @webmethod(route="/scoring-functions", method="POST") async def register_scoring_function( diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index ec1179ac4..67f3bd27b 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -49,7 +49,7 @@ class Shields(Protocol): async def list_shields(self) -> ListShieldsResponse: ... @webmethod(route="/shields/{identifier:path}", method="GET") - async def get_shield(self, identifier: str) -> Optional[Shield]: ... + async def get_shield(self, identifier: str) -> Shield: ... @webmethod(route="/shields", method="POST") async def register_shield( diff --git a/llama_stack/apis/vector_dbs/vector_dbs.py b/llama_stack/apis/vector_dbs/vector_dbs.py index 9a4aa322f..fe6c33919 100644 --- a/llama_stack/apis/vector_dbs/vector_dbs.py +++ b/llama_stack/apis/vector_dbs/vector_dbs.py @@ -50,7 +50,7 @@ class VectorDBs(Protocol): async def get_vector_db( self, vector_db_id: str, - ) -> Optional[VectorDB]: ... + ) -> VectorDB: ... @webmethod(route="/vector-dbs", method="POST") async def register_vector_db( diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 533993421..5dea942f7 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -219,8 +219,11 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): async def list_models(self) -> ListModelsResponse: return ListModelsResponse(data=await self.get_all_with_type("model")) - async def get_model(self, model_id: str) -> Optional[Model]: - return await self.get_object_by_identifier("model", model_id) + async def get_model(self, model_id: str) -> Model: + model = await self.get_object_by_identifier("model", model_id) + if model is None: + raise ValueError(f"Model '{model_id}' not found") + return model async def register_model( self, @@ -267,8 +270,11 @@ class ShieldsRoutingTable(CommonRoutingTableImpl, Shields): async def list_shields(self) -> ListShieldsResponse: return ListShieldsResponse(data=await self.get_all_with_type(ResourceType.shield.value)) - async def get_shield(self, identifier: str) -> Optional[Shield]: - return await self.get_object_by_identifier("shield", identifier) + async def get_shield(self, identifier: str) -> Shield: + shield = await self.get_object_by_identifier("shield", identifier) + if shield is None: + raise ValueError(f"Shield '{identifier}' not found") + return shield async def register_shield( self, @@ -303,8 +309,11 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): async def list_vector_dbs(self) -> ListVectorDBsResponse: return ListVectorDBsResponse(data=await self.get_all_with_type("vector_db")) - async def get_vector_db(self, vector_db_id: str) -> Optional[VectorDB]: - return await self.get_object_by_identifier("vector_db", vector_db_id) + async def get_vector_db(self, vector_db_id: str) -> VectorDB: + vector_db = await self.get_object_by_identifier("vector_db", vector_db_id) + if vector_db is None: + raise ValueError(f"Vector DB '{vector_db_id}' not found") + return vector_db async def register_vector_db( self, @@ -355,8 +364,11 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): async def list_datasets(self) -> ListDatasetsResponse: return ListDatasetsResponse(data=await self.get_all_with_type(ResourceType.dataset.value)) - async def get_dataset(self, dataset_id: str) -> Optional[Dataset]: - return await self.get_object_by_identifier("dataset", dataset_id) + async def get_dataset(self, dataset_id: str) -> Dataset: + dataset = await self.get_object_by_identifier("dataset", dataset_id) + if dataset is None: + raise ValueError(f"Dataset '{dataset_id}' not found") + return dataset async def register_dataset( self, @@ -408,8 +420,11 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions): async def list_scoring_functions(self) -> ListScoringFunctionsResponse: return ListScoringFunctionsResponse(data=await self.get_all_with_type(ResourceType.scoring_function.value)) - async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: - return await self.get_object_by_identifier("scoring_function", scoring_fn_id) + async def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: + scoring_fn = await self.get_object_by_identifier("scoring_function", scoring_fn_id) + if scoring_fn is None: + raise ValueError(f"Scoring function '{scoring_fn_id}' not found") + return scoring_fn async def register_scoring_function( self, @@ -445,8 +460,11 @@ class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks): async def list_benchmarks(self) -> ListBenchmarksResponse: return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark")) - async def get_benchmark(self, benchmark_id: str) -> Optional[Benchmark]: - return await self.get_object_by_identifier("benchmark", benchmark_id) + async def get_benchmark(self, benchmark_id: str) -> Benchmark: + benchmark = await self.get_object_by_identifier("benchmark", benchmark_id) + if benchmark is None: + raise ValueError(f"Benchmark '{benchmark_id}' not found") + return benchmark async def register_benchmark( self, @@ -490,7 +508,10 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): return ListToolGroupsResponse(data=await self.get_all_with_type("tool_group")) async def get_tool_group(self, toolgroup_id: str) -> ToolGroup: - return await self.get_object_by_identifier("tool_group", toolgroup_id) + tool_group = await self.get_object_by_identifier("tool_group", toolgroup_id) + if tool_group is None: + raise ValueError(f"Tool group '{toolgroup_id}' not found") + return tool_group async def get_tool(self, tool_name: str) -> Tool: return await self.get_object_by_identifier("tool", tool_name) From d609ffce2adef44c9448691085ea9fb9fe57a9b1 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Tue, 18 Mar 2025 17:12:17 -0400 Subject: [PATCH 06/19] chore: Add links and badges to both unit and integration tests (#1632) # What does this PR do? This makes it easier to know the statuses of both and identifying failed builds. Signed-off-by: Yuan Tang --- .github/workflows/integration-tests.yml | 2 +- README.md | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 86adf8a15..0af46e1f0 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -1,4 +1,4 @@ -name: Integration tests +name: Integration Tests on: push: diff --git a/README.md b/README.md index aade9c15f..d2adc3376 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,8 @@ [![PyPI - Downloads](https://img.shields.io/pypi/dm/llama-stack)](https://pypi.org/project/llama-stack/) [![License](https://img.shields.io/pypi/l/llama_stack.svg)](https://github.com/meta-llama/llama-stack/blob/main/LICENSE) [![Discord](https://img.shields.io/discord/1257833999603335178)](https://discord.gg/llama-stack) -![Unit](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml/badge.svg?branch=main) +[![Unit Tests](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain) +[![Integration Tests](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain) [**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) From 5ece26297642d99e183c050be153707e8f9828ca Mon Sep 17 00:00:00 2001 From: Sarthak Deshpande <60317842+cheesecake100201@users.noreply.github.com> Date: Wed, 19 Mar 2025 02:43:46 +0530 Subject: [PATCH 07/19] chore: Make code interpreter async (#1654) # What does this PR do? Made code interpreter tool call to be async such that its non blocking ## Test Plan pytest -s -v tests/integration/agents/test_agents.py --stack-config=together --text-model=meta-llama/Llama-3.3-70B-Instruct image [//]: # (## Documentation) Co-authored-by: sarthakdeshpande --- .../inline/tool_runtime/code_interpreter/code_interpreter.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py index 4b97914c5..9610b9b46 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_interpreter.py @@ -5,6 +5,7 @@ # the root directory of this source tree. +import asyncio import logging import os import tempfile @@ -37,7 +38,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): async def initialize(self): pass - async def register_tool(self, tool: Tool): + async def register_tool(self, tool: Tool) -> None: pass async def unregister_tool(self, tool_id: str) -> None: @@ -65,7 +66,7 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime): # Use environment variable to control bwrap usage force_disable_bwrap = os.environ.get("DISABLE_CODE_SANDBOX", "").lower() in ("1", "true", "yes") req = CodeExecutionRequest(scripts=[script], use_bwrap=not force_disable_bwrap) - res = self.code_executor.execute(req) + res = await asyncio.to_thread(self.code_executor.execute, req) pieces = [res["process_status"]] for out_type in ["stdout", "stderr"]: res_out = res[out_type] From 22e560351e5b74ffcb3bfd5a30a9b4772c130b57 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Tue, 18 Mar 2025 17:39:22 -0400 Subject: [PATCH 08/19] ci: Add scheduled workflow to update changelog (#1503) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? This is a follow up from https://github.com/meta-llama/llama-stack/pull/1463. cc @yanxi0830 --------- Signed-off-by: Yuan Tang Co-authored-by: Sébastien Han --- .github/workflows/changelog.yml | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 .github/workflows/changelog.yml diff --git a/.github/workflows/changelog.yml b/.github/workflows/changelog.yml new file mode 100644 index 000000000..5b63e231c --- /dev/null +++ b/.github/workflows/changelog.yml @@ -0,0 +1,29 @@ +name: Update Changelog + +on: + release: + types: [published, unpublished, created, edited, deleted, released] + +permissions: + contents: read + +jobs: + generate_changelog: + name: Generate changelog + permissions: + contents: write # for peter-evans/create-pull-request to create branch + pull-requests: write # for peter-evans/create-pull-request to create a PR + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + ref: main + fetch-depth: 0 + - run: | + python ./scripts/gen-changelog.py + - uses: peter-evans/create-pull-request@v7 + with: + title: 'docs: update CHANGELOG.md for ${{ github.ref_name }}' + commit-message: 'docs: update CHANGELOG.md for ${{ github.ref_name }}' + branch: create-pull-request/changelog + signoff: true From f86f3cf8783e8923f9c67658d06187a6535e842f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 18 Mar 2025 22:52:21 +0100 Subject: [PATCH 09/19] docs: remove redundant installation instructions (#1138) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? The previous installation instructions were mostly duplicating information already covered in the documentation, either in the “Start a Server” or “Contributing Guide” sections. Removed these redundant details to avoid confusion and streamline the setup process. Signed-off-by: Sébastien Han Signed-off-by: Sébastien Han --- README.md | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/README.md b/README.md index d2adc3376..918433d51 100644 --- a/README.md +++ b/README.md @@ -73,26 +73,6 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider | Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/fireworks.html) | | vLLM | [llamastack/distribution-remote-vllm](https://hub.docker.com/repository/docker/llamastack/distribution-remote-vllm/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/remote-vllm.html) | -### Installation - -You have two ways to install this repository: - -* **Install as a package**: - You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command: - ```bash - pip install llama-stack - ``` - -* **Install from source**: - If you prefer to install from the source code, we recommend using [uv](https://github.com/astral-sh/uv). - Then, run the following commands: - ```bash - git clone git@github.com:meta-llama/llama-stack.git - cd llama-stack - - uv sync - uv pip install -e . - ``` ### Documentation From 0cbb7f7f21982bf943a257009bd916dbfe510122 Mon Sep 17 00:00:00 2001 From: Ihar Hrachyshka Date: Tue, 18 Mar 2025 17:58:16 -0400 Subject: [PATCH 10/19] chore: fix mypy violations in post_training modules (#1548) # What does this PR do? Fixes a bunch of violations. Note: this patch touches all files but post_training.py that will be significantly changed by #1437, hence leaving it out of the picture for now. [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan Testing with https://github.com/meta-llama/llama-stack/pull/1543 Also checked that GPU training works with the change: ``` INFO: ::1:53316 - "POST /v1/post-training/supervised-fine-tune HTTP/1.1" 200 OK INFO: ::1:53316 - "GET /v1/post-training/job/status?job_uuid=test-jobb5ca2d84-d541-42f8-883b-762828b4c0e7 HTTP/1.1" 200 OK INFO: ::1:53316 - "GET /v1/post-training/job/artifacts?job_uuid=test-jobb5ca2d84-d541-42f8-883b-762828b4c0e7 HTTP/1.1" 200 OK 21:24:01.161 [END] /v1/post-training/supervised-fine-tune [StatusCode.OK] (32526.75ms) 21:23:28.769 [DEBUG] Setting manual seed to local seed 3918872849. Local seed is seed + rank = 3918872849 + 0 21:23:28.996 [INFO] Identified model_type = Llama3_2. Ignoring output.weight in checkpoint in favor of the tok_embedding.weight tied weights. 21:23:29.933 [INFO] Memory stats after model init: GPU peak memory allocation: 6.05 GiB GPU peak memory reserved: 6.10 GiB GPU peak memory active: 6.05 GiB 21:23:29.934 [INFO] Model is initialized with precision torch.bfloat16. 21:23:30.115 [INFO] Tokenizer is initialized. 21:23:30.118 [INFO] Optimizer is initialized. 21:23:30.119 [INFO] Loss is initialized. 21:23:30.896 [INFO] Dataset and Sampler are initialized. 21:23:30.898 [INFO] Learning rate scheduler is initialized. 21:23:31.618 [INFO] Memory stats after model init: GPU peak memory allocation: 6.24 GiB GPU peak memory reserved: 6.30 GiB GPU peak memory active: 6.24 GiB 21:23:31.620 [INFO] Starting checkpoint save... 21:23:59.428 [INFO] Model checkpoint of size 6.43 GB saved to /home/ec2-user/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0/consolidated.00.pth 21:23:59.445 [INFO] Adapter checkpoint of size 0.00 GB saved to /home/ec2-user/.llama/checkpoints/meta-llama/Llama-3.2-3B-Instruct-sft-0/adapter/adapter.pth ``` [//]: # (## Documentation) Signed-off-by: Ihar Hrachyshka --- docs/_static/llama-stack-spec.html | 26 ++++------- docs/_static/llama-stack-spec.yaml | 13 ++---- .../apis/post_training/post_training.py | 6 +-- .../inline/post_training/common/validator.py | 8 +++- .../torchtune/common/checkpointer.py | 8 ++-- .../post_training/torchtune/common/utils.py | 13 +++--- .../post_training/torchtune/datasets/sft.py | 2 +- .../recipes/lora_finetuning_single_device.py | 45 ++++++++++--------- pyproject.toml | 4 -- 9 files changed, 56 insertions(+), 69 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 72b2e6b17..2362dfa53 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -9847,23 +9847,6 @@ ], "title": "ScoreBatchResponse" }, - "AlgorithmConfig": { - "oneOf": [ - { - "$ref": "#/components/schemas/LoraFinetuningConfig" - }, - { - "$ref": "#/components/schemas/QATFinetuningConfig" - } - ], - "discriminator": { - "propertyName": "type", - "mapping": { - "LoRA": "#/components/schemas/LoraFinetuningConfig", - "QAT": "#/components/schemas/QATFinetuningConfig" - } - } - }, "LoraFinetuningConfig": { "type": "object", "properties": { @@ -9999,7 +9982,14 @@ "type": "string" }, "algorithm_config": { - "$ref": "#/components/schemas/AlgorithmConfig" + "oneOf": [ + { + "$ref": "#/components/schemas/LoraFinetuningConfig" + }, + { + "$ref": "#/components/schemas/QATFinetuningConfig" + } + ] } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 6f4a9528b..38e08e41c 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -6678,15 +6678,6 @@ components: required: - results title: ScoreBatchResponse - AlgorithmConfig: - oneOf: - - $ref: '#/components/schemas/LoraFinetuningConfig' - - $ref: '#/components/schemas/QATFinetuningConfig' - discriminator: - propertyName: type - mapping: - LoRA: '#/components/schemas/LoraFinetuningConfig' - QAT: '#/components/schemas/QATFinetuningConfig' LoraFinetuningConfig: type: object properties: @@ -6770,7 +6761,9 @@ components: checkpoint_dir: type: string algorithm_config: - $ref: '#/components/schemas/AlgorithmConfig' + oneOf: + - $ref: '#/components/schemas/LoraFinetuningConfig' + - $ref: '#/components/schemas/QATFinetuningConfig' additionalProperties: false required: - job_uuid diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index 636eb7e7b..362f87a26 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -6,7 +6,7 @@ from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Protocol, Union +from typing import Any, Dict, List, Literal, Optional, Protocol from pydantic import BaseModel, Field from typing_extensions import Annotated @@ -89,7 +89,7 @@ class QATFinetuningConfig(BaseModel): AlgorithmConfig = register_schema( - Annotated[Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")], + Annotated[LoraFinetuningConfig | QATFinetuningConfig, Field(discriminator="type")], name="AlgorithmConfig", ) @@ -184,7 +184,7 @@ class PostTraining(Protocol): description="Model descriptor from `llama model list`", ), checkpoint_dir: Optional[str] = None, - algorithm_config: Optional[AlgorithmConfig] = None, + algorithm_config: Optional[LoraFinetuningConfig | QATFinetuningConfig] = None, ) -> PostTrainingJob: ... @webmethod(route="/post-training/preference-optimize", method="POST") diff --git a/llama_stack/providers/inline/post_training/common/validator.py b/llama_stack/providers/inline/post_training/common/validator.py index e76edf3a0..b0aec6187 100644 --- a/llama_stack/providers/inline/post_training/common/validator.py +++ b/llama_stack/providers/inline/post_training/common/validator.py @@ -9,6 +9,9 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + +from typing import Any + from llama_stack.apis.common.type_system import ( ChatCompletionInputType, DialogType, @@ -20,7 +23,7 @@ from llama_stack.providers.utils.common.data_schema_validator import ( validate_dataset_schema, ) -EXPECTED_DATASET_SCHEMA = { +EXPECTED_DATASET_SCHEMA: dict[str, list[dict[str, Any]]] = { "instruct": [ { ColumnName.chat_completion_input.value: ChatCompletionInputType(), @@ -41,6 +44,9 @@ async def validate_input_dataset_schema( dataset_type: str, ) -> None: dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id) + if not dataset_def: + raise ValueError(f"Dataset {dataset_id} does not exist.") + if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0: raise ValueError(f"Dataset {dataset_id} does not have a schema defined.") diff --git a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py index 64d61b053..fcadd0884 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/checkpointer.py @@ -37,7 +37,7 @@ class TorchtuneCheckpointer: checkpoint_files: List[str], output_dir: str, model_type: str, - ) -> None: + ): # Fail fast if ``checkpoint_files`` is invalid # TODO: support loading more than one file if len(checkpoint_files) != 1: @@ -58,7 +58,7 @@ class TorchtuneCheckpointer: """ Load Meta checkpoint from file. Currently only loading from a single file is supported. """ - state_dict: Dict[str:Any] = {} + state_dict: Dict[str, Any] = {} model_state_dict = safe_torch_load(self._checkpoint_path) if self._model_type == ModelType.LLAMA3_VISION: from torchtune.models.llama3_2_vision._convert_weights import ( @@ -85,10 +85,10 @@ class TorchtuneCheckpointer: state_dict: Dict[str, Any], epoch: int, adapter_only: bool = False, - checkpoint_format: str = "meta", + checkpoint_format: str | None = None, ) -> str: model_file_path = Path(self._output_dir) / f"{self._model_id}-{self._training_algorithm}-{epoch}" - if checkpoint_format == "meta": + if checkpoint_format == "meta" or checkpoint_format is None: self._save_meta_format_checkpoint(model_file_path, state_dict, adapter_only) elif checkpoint_format == "huggingface": # Note: for saving hugging face format checkpoints, we only suppport saving adapter weights now diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py index 98e16f9d7..f8a1c0436 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py @@ -10,7 +10,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from typing import Any, Callable, Dict +from typing import Callable, Dict import torch from pydantic import BaseModel @@ -25,10 +25,13 @@ from llama_stack.apis.post_training import DatasetFormat from llama_stack.models.llama.datatypes import Model from llama_stack.models.llama.sku_list import resolve_model +BuildLoraModelCallable = Callable[..., torch.nn.Module] +BuildTokenizerCallable = Callable[..., Llama3Tokenizer] + class ModelConfig(BaseModel): - model_definition: Any - tokenizer_type: Any + model_definition: BuildLoraModelCallable + tokenizer_type: BuildTokenizerCallable checkpoint_type: str @@ -51,10 +54,6 @@ DATA_FORMATS: Dict[str, Transform] = { } -BuildLoraModelCallable = Callable[..., torch.nn.Module] -BuildTokenizerCallable = Callable[..., Llama3Tokenizer] - - def _validate_model_id(model_id: str) -> Model: model = resolve_model(model_id) if model is None or model.core_model_id.value not in MODEL_CONFIGS: diff --git a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py index b556b59a6..050996860 100644 --- a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py +++ b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py @@ -55,7 +55,7 @@ class SFTDataset(Dataset): if "messages" in transformed_sample: validate_messages(transformed_sample["messages"]) - tokenized_dict = self._model_transform(transformed_sample) + tokenized_dict: dict[str, Any] = self._model_transform(transformed_sample) if not ("tokens" in tokenized_dict and "mask" in tokenized_dict): keys_str = ", ".join(tokenized_dict.keys()) diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index 0f89b4064..edc1ceb90 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -37,10 +37,10 @@ from llama_stack.apis.common.training_types import PostTrainingMetric from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.post_training import ( - AlgorithmConfig, Checkpoint, LoraFinetuningConfig, OptimizerConfig, + QATFinetuningConfig, TrainingConfig, ) from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR @@ -73,6 +73,9 @@ class LoraFinetuningSingleDevice: # Currently logging only logs limited training metrics to local disk # will figure out more loggings and how it works with telemetry in future PRs + + _checkpointer: TorchtuneCheckpointer + def __init__( self, config: TorchtunePostTrainingConfig, @@ -82,7 +85,7 @@ class LoraFinetuningSingleDevice: logger_config: Dict[str, Any], model: str, checkpoint_dir: Optional[str], - algorithm_config: Optional[AlgorithmConfig], + algorithm_config: LoraFinetuningConfig | QATFinetuningConfig | None, datasetio_api: DatasetIO, datasets_api: Datasets, ) -> None: @@ -109,12 +112,12 @@ class LoraFinetuningSingleDevice: return str(checkpoint_dir) if checkpoint_dir and checkpoint_dir != "null": - self.checkpoint_dir = config.checkpoint_dir + self.checkpoint_dir = checkpoint_dir else: - model = resolve_model(self.model_id) - if model is None: + model_obj = resolve_model(self.model_id) + if model_obj is None: raise ValueError(f"{self.model_id} not found. Your model id should be in the llama models SKU list") - self.checkpoint_dir = model_checkpoint_dir(model) + self.checkpoint_dir = model_checkpoint_dir(model_obj) self._output_dir = str(DEFAULT_CHECKPOINT_DIR) self._checkpoint_format = config.checkpoint_format @@ -135,16 +138,16 @@ class LoraFinetuningSingleDevice: self.max_validation_steps = training_config.max_validation_steps self._clip_grad_norm = 1.0 - self._enable_activation_checkpointing = ( - (training_config.efficiency_config.enable_activation_checkpointing) - if training_config.efficiency_config - else False - ) - self._enable_activation_offloading = ( - (training_config.efficiency_config.enable_activation_offloading) - if training_config.efficiency_config - else False - ) + + self._enable_activation_checkpointing = False + self._enable_activation_offloading = False + if training_config.efficiency_config: + if training_config.efficiency_config.enable_activation_checkpointing: + self._enable_activation_checkpointing = ( + training_config.efficiency_config.enable_activation_checkpointing + ) + if training_config.efficiency_config.enable_activation_offloading: + self._enable_activation_offloading = training_config.efficiency_config.enable_activation_offloading self.datasetio_api = datasetio_api self.datasets_api = datasets_api @@ -451,12 +454,12 @@ class LoraFinetuningSingleDevice: """ # Initialize tokens count and running loss (for grad accumulation) t0 = time.perf_counter() - running_loss = 0 + running_loss: float = 0.0 num_tokens = 0 # training artifacts checkpoints = [] - memory_stats = {} + memory_stats: Dict[str, Any] = {} # self.epochs_run should be non-zero when we're resuming from a checkpoint for curr_epoch in range(self.epochs_run, self.total_epochs): @@ -484,7 +487,7 @@ class LoraFinetuningSingleDevice: # Loss is normalized by default so we multiply by the number of tokens # This way we can normalize by the total number of tokens if we're accumulating gradients current_loss = await self._loss_step(batch) * current_num_tokens - running_loss += current_loss + running_loss += current_loss.detach().item() current_loss.backward() # Step with optimizer @@ -500,7 +503,7 @@ class LoraFinetuningSingleDevice: # Update the number of steps when the weights are updated self.global_step += 1 - loss_to_log = running_loss.item() / num_tokens + loss_to_log = running_loss / num_tokens pbar.update(1) pbar.set_description(f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}") @@ -523,7 +526,7 @@ class LoraFinetuningSingleDevice: ) # Reset running stats for the next step - running_loss = 0 + running_loss = 0.0 num_tokens = 0 t0 = time.perf_counter() diff --git a/pyproject.toml b/pyproject.toml index f57b91462..107150cee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -228,10 +228,6 @@ exclude = [ "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$", "^llama_stack/providers/inline/inference/vllm/", "^llama_stack/providers/inline/post_training/common/validator\\.py$", - "^llama_stack/providers/inline/post_training/torchtune/common/checkpointer\\.py$", - "^llama_stack/providers/inline/post_training/torchtune/common/utils\\.py$", - "^llama_stack/providers/inline/post_training/torchtune/datasets/sft\\.py$", - "^llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device\\.py$", "^llama_stack/providers/inline/post_training/torchtune/post_training\\.py$", "^llama_stack/providers/inline/safety/code_scanner/", "^llama_stack/providers/inline/safety/llama_guard/", From 9c8e88ea9ca756dd10b2db0e68a4166e35c6e5ff Mon Sep 17 00:00:00 2001 From: Sarthak Deshpande <60317842+cheesecake100201@users.noreply.github.com> Date: Wed, 19 Mar 2025 03:30:48 +0530 Subject: [PATCH 11/19] fix: Fixed import errors for UI and playground (#1666) # What does this PR do? Fixed import errors for playground and ui --------- Co-authored-by: sarthakdeshpande --- .../distribution/ui/page/distribution/datasets.py | 3 ++- .../distribution/ui/page/distribution/eval_tasks.py | 3 ++- .../distribution/ui/page/distribution/models.py | 3 ++- .../distribution/ui/page/distribution/providers.py | 3 ++- .../distribution/ui/page/distribution/resources.py | 13 +++++++------ .../ui/page/distribution/scoring_functions.py | 3 ++- .../distribution/ui/page/distribution/shields.py | 3 ++- .../distribution/ui/page/distribution/vector_dbs.py | 3 ++- .../distribution/ui/page/evaluations/app_eval.py | 5 +++-- .../distribution/ui/page/evaluations/native_eval.py | 3 ++- llama_stack/distribution/ui/page/playground/chat.py | 3 ++- llama_stack/distribution/ui/page/playground/rag.py | 7 ++++--- 12 files changed, 32 insertions(+), 20 deletions(-) diff --git a/llama_stack/distribution/ui/page/distribution/datasets.py b/llama_stack/distribution/ui/page/distribution/datasets.py index b583c93fd..6842b29a7 100644 --- a/llama_stack/distribution/ui/page/distribution/datasets.py +++ b/llama_stack/distribution/ui/page/distribution/datasets.py @@ -5,7 +5,8 @@ # the root directory of this source tree. import streamlit as st -from modules.api import llama_stack_api + +from llama_stack.distribution.ui.modules.api import llama_stack_api def datasets(): diff --git a/llama_stack/distribution/ui/page/distribution/eval_tasks.py b/llama_stack/distribution/ui/page/distribution/eval_tasks.py index 1428ae9ab..492be4700 100644 --- a/llama_stack/distribution/ui/page/distribution/eval_tasks.py +++ b/llama_stack/distribution/ui/page/distribution/eval_tasks.py @@ -5,7 +5,8 @@ # the root directory of this source tree. import streamlit as st -from modules.api import llama_stack_api + +from llama_stack.distribution.ui.modules.api import llama_stack_api def benchmarks(): diff --git a/llama_stack/distribution/ui/page/distribution/models.py b/llama_stack/distribution/ui/page/distribution/models.py index 3141c1627..f29459098 100644 --- a/llama_stack/distribution/ui/page/distribution/models.py +++ b/llama_stack/distribution/ui/page/distribution/models.py @@ -5,7 +5,8 @@ # the root directory of this source tree. import streamlit as st -from modules.api import llama_stack_api + +from llama_stack.distribution.ui.modules.api import llama_stack_api def models(): diff --git a/llama_stack/distribution/ui/page/distribution/providers.py b/llama_stack/distribution/ui/page/distribution/providers.py index 9aeb7f2a5..c660cb986 100644 --- a/llama_stack/distribution/ui/page/distribution/providers.py +++ b/llama_stack/distribution/ui/page/distribution/providers.py @@ -5,7 +5,8 @@ # the root directory of this source tree. import streamlit as st -from modules.api import llama_stack_api + +from llama_stack.distribution.ui.modules.api import llama_stack_api def providers(): diff --git a/llama_stack/distribution/ui/page/distribution/resources.py b/llama_stack/distribution/ui/page/distribution/resources.py index 684270d4d..5e10e6e80 100644 --- a/llama_stack/distribution/ui/page/distribution/resources.py +++ b/llama_stack/distribution/ui/page/distribution/resources.py @@ -4,14 +4,15 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from page.distribution.benchmarks import benchmarks -from page.distribution.datasets import datasets -from page.distribution.models import models -from page.distribution.scoring_functions import scoring_functions -from page.distribution.shields import shields -from page.distribution.vector_dbs import vector_dbs from streamlit_option_menu import option_menu +from llama_stack.distribution.ui.page.distribution.datasets import datasets +from llama_stack.distribution.ui.page.distribution.eval_tasks import benchmarks +from llama_stack.distribution.ui.page.distribution.models import models +from llama_stack.distribution.ui.page.distribution.scoring_functions import scoring_functions +from llama_stack.distribution.ui.page.distribution.shields import shields +from llama_stack.distribution.ui.page.distribution.vector_dbs import vector_dbs + def resources_page(): options = [ diff --git a/llama_stack/distribution/ui/page/distribution/scoring_functions.py b/llama_stack/distribution/ui/page/distribution/scoring_functions.py index 6a2a08c6d..193146356 100644 --- a/llama_stack/distribution/ui/page/distribution/scoring_functions.py +++ b/llama_stack/distribution/ui/page/distribution/scoring_functions.py @@ -5,7 +5,8 @@ # the root directory of this source tree. import streamlit as st -from modules.api import llama_stack_api + +from llama_stack.distribution.ui.modules.api import llama_stack_api def scoring_functions(): diff --git a/llama_stack/distribution/ui/page/distribution/shields.py b/llama_stack/distribution/ui/page/distribution/shields.py index b5ed27ef9..67d66d64f 100644 --- a/llama_stack/distribution/ui/page/distribution/shields.py +++ b/llama_stack/distribution/ui/page/distribution/shields.py @@ -5,7 +5,8 @@ # the root directory of this source tree. import streamlit as st -from modules.api import llama_stack_api + +from llama_stack.distribution.ui.modules.api import llama_stack_api def shields(): diff --git a/llama_stack/distribution/ui/page/distribution/vector_dbs.py b/llama_stack/distribution/ui/page/distribution/vector_dbs.py index 1c9d06e8d..49a4f25bb 100644 --- a/llama_stack/distribution/ui/page/distribution/vector_dbs.py +++ b/llama_stack/distribution/ui/page/distribution/vector_dbs.py @@ -5,7 +5,8 @@ # the root directory of this source tree. import streamlit as st -from modules.api import llama_stack_api + +from llama_stack.distribution.ui.modules.api import llama_stack_api def vector_dbs(): diff --git a/llama_stack/distribution/ui/page/evaluations/app_eval.py b/llama_stack/distribution/ui/page/evaluations/app_eval.py index 26bc28451..d7bc6388c 100644 --- a/llama_stack/distribution/ui/page/evaluations/app_eval.py +++ b/llama_stack/distribution/ui/page/evaluations/app_eval.py @@ -8,8 +8,9 @@ import json import pandas as pd import streamlit as st -from modules.api import llama_stack_api -from modules.utils import process_dataset + +from llama_stack.distribution.ui.modules.api import llama_stack_api +from llama_stack.distribution.ui.modules.utils import process_dataset def application_evaluation_page(): diff --git a/llama_stack/distribution/ui/page/evaluations/native_eval.py b/llama_stack/distribution/ui/page/evaluations/native_eval.py index 7c39adc4a..97f875e17 100644 --- a/llama_stack/distribution/ui/page/evaluations/native_eval.py +++ b/llama_stack/distribution/ui/page/evaluations/native_eval.py @@ -8,7 +8,8 @@ import json import pandas as pd import streamlit as st -from modules.api import llama_stack_api + +from llama_stack.distribution.ui.modules.api import llama_stack_api def select_benchmark_1(): diff --git a/llama_stack/distribution/ui/page/playground/chat.py b/llama_stack/distribution/ui/page/playground/chat.py index e69f559db..8e7345169 100644 --- a/llama_stack/distribution/ui/page/playground/chat.py +++ b/llama_stack/distribution/ui/page/playground/chat.py @@ -5,7 +5,8 @@ # the root directory of this source tree. import streamlit as st -from modules.api import llama_stack_api + +from llama_stack.distribution.ui.modules.api import llama_stack_api # Sidebar configurations with st.sidebar: diff --git a/llama_stack/distribution/ui/page/playground/rag.py b/llama_stack/distribution/ui/page/playground/rag.py index 7ee934fb7..e2f451668 100644 --- a/llama_stack/distribution/ui/page/playground/rag.py +++ b/llama_stack/distribution/ui/page/playground/rag.py @@ -7,9 +7,10 @@ import streamlit as st from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.event_logger import EventLogger -from llama_stack_client.types.memory_insert_params import Document -from modules.api import llama_stack_api -from modules.utils import data_url_from_file +from llama_stack_client.types.shared.document import Document + +from llama_stack.distribution.ui.modules.api import llama_stack_api +from llama_stack.distribution.ui.modules.utils import data_url_from_file def rag_chat_page(): From b79e0435de6be38a6dd4061b8748939305815750 Mon Sep 17 00:00:00 2001 From: yyymeta <123776235+yyymeta@users.noreply.github.com> Date: Tue, 18 Mar 2025 16:17:29 -0700 Subject: [PATCH 12/19] fix: avoid tensor memory error (#1688) # What does this PR do? we randomly get errors like the following, it's most likely due to accessing an object that is already deallocated ``` E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] Traceback (most recent call last): E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] File "/home/yyy/.conda/envs/myenv/lib/python3.10/site-packages/torch/multiprocessing/spawn.py", line 90, in _wrap E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] fn(i, *args) E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] File "/home/yyy/.conda/envs/myenv/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/api.py", line 611, in _wrap E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] ret = record(fn)(*args_) E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] File "/home/yyy/.conda/envs/myenv/lib/python3.10/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] return f(*args, **kwargs) E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] File "/home/yyy/internal-llama-stack/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py", line 249, in worker_process_entrypoint E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] task = req_gen.send(result) E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] File "/home/yyy/internal-llama-stack/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py", line 156, in retrieve_requests E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] torch.distributed.broadcast_object_list( E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] File "/home/yyy/.conda/envs/myenv/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 81, in wrapper E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] return func(*args, **kwargs) E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] File "/home/yyy/.conda/envs/myenv/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 3504, in broadcast_object_list E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] object_list[i] = _tensor_to_object(obj_view, obj_size, group) E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] File "/home/yyy/.conda/envs/myenv/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2961, in _tensor_to_object E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] return _unpickler(io.BytesIO(buf)).load() E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] EOFError: Ran out of input E0318 12:55:24.472000 1562188 site-packages/torch/distributed/elastic/multiprocessing/api.py:732] Process SpawnProcess-1: Traceback (most recent call last): ``` ## Test Plan start server ``` llama-stack-client eval run-benchmark mmmu_v1 --model-id meta-llama/Llama-4-17B-Omni-Instruct --output-dir /tmp/mmmu_standard --num-examples 30 ``` [//]: # (## Documentation) --- .../inline/inference/meta_reference/parallel_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 738f9ddcd..e8767c2ff 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -10,6 +10,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import copy import json import logging import multiprocessing @@ -213,7 +214,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage def parse_message(json_str: str) -> ProcessingMessage: data = json.loads(json_str) - return ProcessingMessageWrapper(**data).payload + return copy.deepcopy(ProcessingMessageWrapper(**data).payload) def worker_process_entrypoint( From 5b39d5a76af13f055974c5cd1d66a31c92f01ccd Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Tue, 18 Mar 2025 16:24:18 -0700 Subject: [PATCH 13/19] feat(auth, rfc): Add support for Bearer (api_key) Authentication (#1626) This PR adds support (or is a proposal for) for supporting API KEY authentication on the Llama Stack server end. `llama-stack-client` already supports accepting an api_key parameter and passes it down through every request as an `Authentication: ` header. Currently, Llama Stack does not propose APIs for handling authentication or authorization for resources of any kind. Given that, and the fact that any deployment will typically have _some_ authentication system present, we simply adopt a delegation mechanism: delegate to an HTTPS endpoint performing key management / authentication. It is configured via: ```yaml server: auth: endpoint: <...> ``` in the run.yaml configuration. ## How It Works When authentication is enabled: 1. Every API request must include an `Authorization: Bearer ` header 2. The server will send a _POST_ validation request to the configured endpoint with the following payload: ```json { "api_key": "", "request": { "path": "/api/path", "headers": { "header1": "value1", ... }, "params": { "param1": "value1", ... } } } ``` 3. If the authentication endpoint returns a 200 status code, the request is allowed to proceed 4. If the authentication endpoint returns any other status code, a 401 Unauthorized response is returned ## Test Plan Unit tests --- llama_stack/distribution/datatypes.py | 11 ++ llama_stack/distribution/server/auth.py | 69 ++++++++++++ llama_stack/distribution/server/server.py | 6 ++ tests/unit/server/test_auth.py | 124 ++++++++++++++++++++++ 4 files changed, 210 insertions(+) create mode 100644 llama_stack/distribution/server/auth.py create mode 100644 tests/unit/server/test_auth.py diff --git a/llama_stack/distribution/datatypes.py b/llama_stack/distribution/datatypes.py index 7e1d8c016..e16e047e5 100644 --- a/llama_stack/distribution/datatypes.py +++ b/llama_stack/distribution/datatypes.py @@ -125,6 +125,13 @@ class LoggingConfig(BaseModel): ) +class AuthenticationConfig(BaseModel): + endpoint: str = Field( + ..., + description="Endpoint URL to validate authentication tokens", + ) + + class ServerConfig(BaseModel): port: int = Field( default=8321, @@ -140,6 +147,10 @@ class ServerConfig(BaseModel): default=None, description="Path to TLS key file for HTTPS", ) + auth: Optional[AuthenticationConfig] = Field( + default=None, + description="Authentication configuration for the server", + ) class StackRunConfig(BaseModel): diff --git a/llama_stack/distribution/server/auth.py b/llama_stack/distribution/server/auth.py new file mode 100644 index 000000000..bb577bae5 --- /dev/null +++ b/llama_stack/distribution/server/auth.py @@ -0,0 +1,69 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from urllib.parse import parse_qs + +import httpx + +from llama_stack.log import get_logger + +logger = get_logger(name=__name__, category="auth") + + +class AuthenticationMiddleware: + def __init__(self, app, auth_endpoint): + self.app = app + self.auth_endpoint = auth_endpoint + + async def __call__(self, scope, receive, send): + if scope["type"] == "http": + headers = dict(scope.get("headers", [])) + auth_header = headers.get(b"authorization", b"").decode() + + if not auth_header or not auth_header.startswith("Bearer "): + return await self._send_auth_error(send, "Missing or invalid Authorization header") + + api_key = auth_header.split("Bearer ", 1)[1] + + path = scope.get("path", "") + request_headers = {k.decode(): v.decode() for k, v in headers.items()} + + query_string = scope.get("query_string", b"").decode() + params = parse_qs(query_string) + + auth_data = { + "api_key": api_key, + "request": { + "path": path, + "headers": request_headers, + "params": params, + }, + } + + # Validate with authentication endpoint + try: + async with httpx.AsyncClient() as client: + response = await client.post(self.auth_endpoint, json=auth_data) + if response.status_code != 200: + logger.warning(f"Authentication failed: {response.status_code}") + return await self._send_auth_error(send, "Authentication failed") + except Exception: + logger.exception("Error during authentication") + return await self._send_auth_error(send, "Authentication service error") + + return await self.app(scope, receive, send) + + async def _send_auth_error(self, send, message): + await send( + { + "type": "http.response.start", + "status": 401, + "headers": [[b"content-type", b"application/json"]], + } + ) + error_msg = json.dumps({"error": {"message": message}}).encode() + await send({"type": "http.response.body", "body": error_msg}) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index b37b3a007..460acbc87 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -52,6 +52,7 @@ from llama_stack.providers.utils.telemetry.tracing import ( start_trace, ) +from .auth import AuthenticationMiddleware from .endpoints import get_all_api_endpoints REPO_ROOT = Path(__file__).parent.parent.parent.parent @@ -351,6 +352,11 @@ def main(): if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): app.add_middleware(ClientVersionMiddleware) + # Add authentication middleware if configured + if config.server.auth and config.server.auth.endpoint: + logger.info(f"Enabling authentication with endpoint: {config.server.auth.endpoint}") + app.add_middleware(AuthenticationMiddleware, auth_endpoint=config.server.auth.endpoint) + try: impls = asyncio.run(construct_stack(config)) except InvalidProviderError as e: diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py new file mode 100644 index 000000000..70f08dbd6 --- /dev/null +++ b/tests/unit/server/test_auth.py @@ -0,0 +1,124 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from llama_stack.distribution.server.auth import AuthenticationMiddleware + + +@pytest.fixture +def mock_auth_endpoint(): + return "http://mock-auth-service/validate" + + +@pytest.fixture +def valid_api_key(): + return "valid_api_key_12345" + + +@pytest.fixture +def invalid_api_key(): + return "invalid_api_key_67890" + + +@pytest.fixture +def app(mock_auth_endpoint): + app = FastAPI() + app.add_middleware(AuthenticationMiddleware, auth_endpoint=mock_auth_endpoint) + + @app.get("/test") + def test_endpoint(): + return {"message": "Authentication successful"} + + return app + + +@pytest.fixture +def client(app): + return TestClient(app) + + +async def mock_post_success(*args, **kwargs): + mock_response = AsyncMock() + mock_response.status_code = 200 + return mock_response + + +async def mock_post_failure(*args, **kwargs): + mock_response = AsyncMock() + mock_response.status_code = 401 + return mock_response + + +async def mock_post_exception(*args, **kwargs): + raise Exception("Connection error") + + +def test_missing_auth_header(client): + response = client.get("/test") + assert response.status_code == 401 + assert "Missing or invalid Authorization header" in response.json()["error"]["message"] + + +def test_invalid_auth_header_format(client): + response = client.get("/test", headers={"Authorization": "InvalidFormat token123"}) + assert response.status_code == 401 + assert "Missing or invalid Authorization header" in response.json()["error"]["message"] + + +@patch("httpx.AsyncClient.post", new=mock_post_success) +def test_valid_authentication(client, valid_api_key): + response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"}) + assert response.status_code == 200 + assert response.json() == {"message": "Authentication successful"} + + +@patch("httpx.AsyncClient.post", new=mock_post_failure) +def test_invalid_authentication(client, invalid_api_key): + response = client.get("/test", headers={"Authorization": f"Bearer {invalid_api_key}"}) + assert response.status_code == 401 + assert "Authentication failed" in response.json()["error"]["message"] + + +@patch("httpx.AsyncClient.post", new=mock_post_exception) +def test_auth_service_error(client, valid_api_key): + response = client.get("/test", headers={"Authorization": f"Bearer {valid_api_key}"}) + assert response.status_code == 401 + assert "Authentication service error" in response.json()["error"]["message"] + + +def test_auth_request_payload(client, valid_api_key, mock_auth_endpoint): + with patch("httpx.AsyncClient.post") as mock_post: + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_post.return_value = mock_response + + client.get( + "/test?param1=value1¶m2=value2", + headers={ + "Authorization": f"Bearer {valid_api_key}", + "User-Agent": "TestClient", + "Content-Type": "application/json", + }, + ) + + # Check that the auth endpoint was called with the correct payload + call_args = mock_post.call_args + assert call_args is not None + + url, kwargs = call_args[0][0], call_args[1] + assert url == mock_auth_endpoint + + payload = kwargs["json"] + assert payload["api_key"] == valid_api_key + assert payload["request"]["path"] == "/test" + assert "authorization" in payload["request"]["headers"] + assert "param1" in payload["request"]["params"] + assert "param2" in payload["request"]["params"] From 7c0448456ed1dbca785606c8bde8797cb1c82704 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Wed, 19 Mar 2025 00:17:22 -0400 Subject: [PATCH 14/19] docs: Remove mentions of focus on Llama models (#1690) # What does this PR do? This is a follow-up of https://github.com/meta-llama/llama-stack/issues/965 to avoid mentioning exclusive support on Llama models. --------- Signed-off-by: Yuan Tang --- docs/source/index.md | 2 -- docs/source/introduction/index.md | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/docs/source/index.md b/docs/source/index.md index 0a8fcb30c..12a27bd2b 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -15,8 +15,6 @@ Llama Stack defines and standardizes the core building blocks needed to bring ge - **Multiple developer interfaces** like CLI and SDKs for Python, Node, iOS, and Android - **Standalone applications** as examples for how to build production-grade AI applications with Llama Stack -We focus on making it easy to build production applications with the Llama model family - from the latest Llama 3.3 to specialized models like Llama Guard for safety. - ```{image} ../_static/llama-stack.png :alt: Llama Stack :width: 400px diff --git a/docs/source/introduction/index.md b/docs/source/introduction/index.md index 686f44cc4..5ffa5e68d 100644 --- a/docs/source/introduction/index.md +++ b/docs/source/introduction/index.md @@ -48,7 +48,7 @@ Llama Stack addresses these challenges through a service-oriented, API-first app **Robust Ecosystem** - Llama Stack is already integrated with distribution partners (cloud providers, hardware vendors, and AI-focused companies). -- Ecosystem offers tailored infrastructure, software, and services for deploying Llama models. +- Ecosystem offers tailored infrastructure, software, and services for deploying a variety of models. ### Our Philosophy @@ -57,7 +57,6 @@ Llama Stack addresses these challenges through a service-oriented, API-first app - **Composability**: Every component is independent but works together seamlessly - **Production Ready**: Built for real-world applications, not just demos - **Turnkey Solutions**: Easy to deploy built in solutions for popular deployment scenarios -- **Llama First**: Explicit focus on Meta's Llama models and partnering ecosystem With Llama Stack, you can focus on building your application while we handle the infrastructure complexity, essential capabilities, and provider integrations. From 5418e63919e11b63fdb833a11910ab1b54858aa7 Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Wed, 19 Mar 2025 10:59:17 -0600 Subject: [PATCH 15/19] chore: Add triagers list #1561 (#1701) # What does this PR do? Adds triagers list ## Closes #1561 ## Documentation Was provided here: https://github.com/meta-llama/llama-stack/pull/1621 Signed-off-by: Francisco Javier Arceo --- .github/TRIAGERS.md | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 .github/TRIAGERS.md diff --git a/.github/TRIAGERS.md b/.github/TRIAGERS.md new file mode 100644 index 000000000..d4ef6d1ac --- /dev/null +++ b/.github/TRIAGERS.md @@ -0,0 +1,2 @@ +# This file documents Triage members in the Llama Stack community +@franciscojavierarceo @leseb From 113f3a259c91bd74881be7434a55e36f860f7e33 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Wed, 19 Mar 2025 10:16:00 -0700 Subject: [PATCH 16/19] docs: add documentation for RAGDocument (#1693) # What does this PR do? ## Test Plan --- docs/_static/llama-stack-spec.html | 15 ++++++++++----- docs/_static/llama-stack-spec.yaml | 6 ++++++ llama_stack/apis/tools/rag_tool.py | 9 +++++++++ 3 files changed, 25 insertions(+), 5 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 2362dfa53..b32b7cfdf 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -7787,7 +7787,8 @@ "type": "object", "properties": { "document_id": { - "type": "string" + "type": "string", + "description": "The unique identifier for the document." }, "content": { "oneOf": [ @@ -7806,10 +7807,12 @@ { "$ref": "#/components/schemas/URL" } - ] + ], + "description": "The content of the document." }, "mime_type": { - "type": "string" + "type": "string", + "description": "The MIME type of the document." }, "metadata": { "type": "object", @@ -7834,7 +7837,8 @@ "type": "object" } ] - } + }, + "description": "Additional metadata for the document." } }, "additionalProperties": false, @@ -7843,7 +7847,8 @@ "content", "metadata" ], - "title": "RAGDocument" + "title": "RAGDocument", + "description": "A document to be used for document ingestion in the RAG Tool." }, "InsertRequest": { "type": "object", diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 38e08e41c..eb5d9722e 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -5375,6 +5375,7 @@ components: properties: document_id: type: string + description: The unique identifier for the document. content: oneOf: - type: string @@ -5383,8 +5384,10 @@ components: items: $ref: '#/components/schemas/InterleavedContentItem' - $ref: '#/components/schemas/URL' + description: The content of the document. mime_type: type: string + description: The MIME type of the document. metadata: type: object additionalProperties: @@ -5395,12 +5398,15 @@ components: - type: string - type: array - type: object + description: Additional metadata for the document. additionalProperties: false required: - document_id - content - metadata title: RAGDocument + description: >- + A document to be used for document ingestion in the RAG Tool. InsertRequest: type: object properties: diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index 2b9ef10d8..671e19619 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -17,6 +17,15 @@ from llama_stack.schema_utils import json_schema_type, register_schema, webmetho @json_schema_type class RAGDocument(BaseModel): + """ + A document to be used for document ingestion in the RAG Tool. + + :param document_id: The unique identifier for the document. + :param content: The content of the document. + :param mime_type: The MIME type of the document. + :param metadata: Additional metadata for the document. + """ + document_id: str content: InterleavedContent | URL mime_type: str | None = None From 65ca85ba6b938bf14a848200ebbf0ad111c837f4 Mon Sep 17 00:00:00 2001 From: Hardik Shah Date: Wed, 19 Mar 2025 10:36:19 -0700 Subject: [PATCH 17/19] fix: Updating `ToolCall.arguments` to allow for json strings that can be decoded on client side (#1685) ### What does this PR do? Currently, `ToolCall.arguments` is a `Dict[str, RecursiveType]`. However, on the client SDK side -- the `RecursiveType` gets deserialized into a number ( both int and float get collapsed ) and hence when params are `int` they get converted to float which might break client side tools that might be doing type checking. Closes: https://github.com/meta-llama/llama-stack/issues/1683 ### Test Plan Stainless changes -- https://github.com/meta-llama/llama-stack-client-python/pull/204 ``` pytest -s -v --stack-config=fireworks tests/integration/agents/test_agents.py --text-model meta-llama/Llama-3.1-8B-Instruct ``` --- docs/_static/llama-stack-spec.html | 132 ++++++++++-------- docs/_static/llama-stack-spec.yaml | 52 +++---- llama_stack/models/llama/datatypes.py | 9 +- .../models/llama/llama3/chat_format.py | 9 +- .../models/llama/llama3/template_data.py | 7 +- .../providers/inline/inference/vllm/vllm.py | 1 + .../remote/inference/sambanova/sambanova.py | 10 +- .../providers/remote/inference/vllm/vllm.py | 8 +- .../utils/inference/openai_compat.py | 14 +- tests/unit/models/test_prompt_adapter.py | 5 +- 10 files changed, 137 insertions(+), 110 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index b32b7cfdf..eb626fc44 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -4159,70 +4159,80 @@ ] }, "arguments": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - }, - { - "type": "number" - }, - { - "type": "boolean" - }, - { - "type": "null" - }, - { - "type": "array", - "items": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - }, - { - "type": "number" - }, - { - "type": "boolean" - }, - { - "type": "null" + "oneOf": [ + { + "type": "string" + }, + { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "integer" + }, + { + "type": "number" + }, + { + "type": "boolean" + }, + { + "type": "null" + }, + { + "type": "array", + "items": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "integer" + }, + { + "type": "number" + }, + { + "type": "boolean" + }, + { + "type": "null" + } + ] } - ] - } - }, - { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "string" - }, - { - "type": "integer" - }, - { - "type": "number" - }, - { - "type": "boolean" - }, - { - "type": "null" + }, + { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "string" + }, + { + "type": "integer" + }, + { + "type": "number" + }, + { + "type": "boolean" + }, + { + "type": "null" + } + ] } - ] - } + } + ] } - ] - } + } + ] + }, + "arguments_json": { + "type": "string" } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index eb5d9722e..fa6920381 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -2864,30 +2864,34 @@ components: title: BuiltinTool - type: string arguments: - type: object - additionalProperties: - oneOf: - - type: string - - type: integer - - type: number - - type: boolean - - type: 'null' - - type: array - items: - oneOf: - - type: string - - type: integer - - type: number - - type: boolean - - type: 'null' - - type: object - additionalProperties: - oneOf: - - type: string - - type: integer - - type: number - - type: boolean - - type: 'null' + oneOf: + - type: string + - type: object + additionalProperties: + oneOf: + - type: string + - type: integer + - type: number + - type: boolean + - type: 'null' + - type: array + items: + oneOf: + - type: string + - type: integer + - type: number + - type: boolean + - type: 'null' + - type: object + additionalProperties: + oneOf: + - type: string + - type: integer + - type: number + - type: boolean + - type: 'null' + arguments_json: + type: string additionalProperties: false required: - call_id diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py index b25bf0ea9..9842d7980 100644 --- a/llama_stack/models/llama/datatypes.py +++ b/llama_stack/models/llama/datatypes.py @@ -47,7 +47,14 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]] class ToolCall(BaseModel): call_id: str tool_name: Union[BuiltinTool, str] - arguments: Dict[str, RecursiveType] + # Plan is to deprecate the Dict in favor of a JSON string + # that is parsed on the client side instead of trying to manage + # the recursive type here. + # Making this a union so that client side can start prepping for this change. + # Eventually, we will remove both the Dict and arguments_json field, + # and arguments will just be a str + arguments: Union[str, Dict[str, RecursiveType]] + arguments_json: Optional[str] = None @field_validator("tool_name", mode="before") @classmethod diff --git a/llama_stack/models/llama/llama3/chat_format.py b/llama_stack/models/llama/llama3/chat_format.py index 011ccb02a..2862f8558 100644 --- a/llama_stack/models/llama/llama3/chat_format.py +++ b/llama_stack/models/llama/llama3/chat_format.py @@ -12,6 +12,7 @@ # the top-level of this source tree. import io +import json import uuid from dataclasses import dataclass from typing import Dict, List, Optional, Tuple @@ -203,9 +204,10 @@ class ChatFormat: # This code tries to handle that case if tool_name in BuiltinTool.__members__: tool_name = BuiltinTool[tool_name] - tool_arguments = { - "query": list(tool_arguments.values())[0], - } + if isinstance(tool_arguments, dict): + tool_arguments = { + "query": list(tool_arguments.values())[0], + } else: builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content) if builtin_tool_info is not None: @@ -229,6 +231,7 @@ class ChatFormat: call_id=call_id, tool_name=tool_name, arguments=tool_arguments, + arguments_json=json.dumps(tool_arguments), ) ) content = "" diff --git a/llama_stack/models/llama/llama3/template_data.py b/llama_stack/models/llama/llama3/template_data.py index aa16aa009..076b4adb4 100644 --- a/llama_stack/models/llama/llama3/template_data.py +++ b/llama_stack/models/llama/llama3/template_data.py @@ -11,11 +11,8 @@ # top-level folder for each specific model found within the models/ directory at # the top-level of this source tree. -from llama_stack.models.llama.datatypes import ( - BuiltinTool, - StopReason, - ToolCall, -) + +from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall from .prompt_templates import ( BuiltinToolGenerator, diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index b59df13d0..256e0f821 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -582,6 +582,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): tool_name=t.function.name, # vLLM function args come back as a string. Llama Stack expects JSON. arguments=json.loads(t.function.arguments), + arguments_json=t.function.arguments, ) for t in vllm_message.tool_calls ], diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index a5e17c2a3..635a42d38 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -42,9 +42,7 @@ from llama_stack.models.llama.datatypes import ( TopKSamplingStrategy, TopPSamplingStrategy, ) -from llama_stack.providers.utils.inference.model_registry import ( - ModelRegistryHelper, -) +from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.openai_compat import ( process_chat_completion_stream_response, ) @@ -293,14 +291,12 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): if not tool_calls: return [] - for call in tool_calls: - call_function_arguments = json.loads(call.function.arguments) - compitable_tool_calls = [ ToolCall( call_id=call.id, tool_name=call.function.name, - arguments=call_function_arguments, + arguments=json.loads(call.function.arguments), + arguments_json=call.function.arguments, ) for call in tool_calls ] diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index f940de7ba..eda1a179c 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -90,15 +90,12 @@ def _convert_to_vllm_tool_calls_in_response( if not tool_calls: return [] - call_function_arguments = None - for call in tool_calls: - call_function_arguments = json.loads(call.function.arguments) - return [ ToolCall( call_id=call.id, tool_name=call.function.name, - arguments=call_function_arguments, + arguments=json.loads(call.function.arguments), + arguments_json=call.function.arguments, ) for call in tool_calls ] @@ -183,6 +180,7 @@ async def _process_vllm_chat_completion_stream_response( call_id=tool_call_buf.call_id, tool_name=tool_call_buf.tool_name, arguments=args, + arguments_json=args_str, ), parse_status=ToolCallParseStatus.succeeded, ), diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 2a362f8cb..b264c7312 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -529,7 +529,11 @@ async def convert_message_to_openai_dict_new( ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]: async def impl( content_: InterleavedContent, - ) -> Union[str, OpenAIChatCompletionContentPartParam, List[OpenAIChatCompletionContentPartParam]]: + ) -> Union[ + str, + OpenAIChatCompletionContentPartParam, + List[OpenAIChatCompletionContentPartParam], + ]: # Llama Stack and OpenAI spec match for str and text input if isinstance(content_, str): return content_ @@ -570,7 +574,7 @@ async def convert_message_to_openai_dict_new( OpenAIChatCompletionMessageToolCall( id=tool.call_id, function=OpenAIFunction( - name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value, + name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value), arguments=json.dumps(tool.arguments), ), type="function", @@ -609,6 +613,7 @@ def convert_tool_call( call_id=tool_call.id, tool_name=tool_call.function.name, arguments=json.loads(tool_call.function.arguments), + arguments_json=tool_call.function.arguments, ) except Exception: return UnparseableToolCall( @@ -759,6 +764,7 @@ def _convert_openai_tool_calls( call_id=call.id, tool_name=call.function.name, arguments=json.loads(call.function.arguments), + arguments_json=call.function.arguments, ) for call in tool_calls ] @@ -890,7 +896,8 @@ async def convert_openai_chat_completion_stream( # ChatCompletionResponseEvent only supports one per stream if len(choice.delta.tool_calls) > 1: warnings.warn( - "multiple tool calls found in a single delta, using the first, ignoring the rest", stacklevel=2 + "multiple tool calls found in a single delta, using the first, ignoring the rest", + stacklevel=2, ) if not enable_incremental_tool_calls: @@ -971,6 +978,7 @@ async def convert_openai_chat_completion_stream( call_id=buffer["call_id"], tool_name=buffer["name"], arguments=arguments, + arguments_json=buffer["arguments"], ) yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( diff --git a/tests/unit/models/test_prompt_adapter.py b/tests/unit/models/test_prompt_adapter.py index c3755e2cb..0e2780e50 100644 --- a/tests/unit/models/test_prompt_adapter.py +++ b/tests/unit/models/test_prompt_adapter.py @@ -165,7 +165,10 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): request.model = MODEL request.tool_config.tool_prompt_format = ToolPromptFormat.json prompt = await chat_completion_request_to_prompt(request, request.model) - self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt) + self.assertIn( + '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', + prompt, + ) async def test_user_provided_system_message(self): content = "Hello !" From 6949bd19998d761003958486e38a2bd53c231d58 Mon Sep 17 00:00:00 2001 From: Derek Higgins Date: Wed, 19 Mar 2025 17:46:37 +0000 Subject: [PATCH 18/19] fix: Call pandas.read_* in a seperate thread (#1698) These block on io reads which in turn block the server. Move them to their own thread. Closes: #1697 # What does this PR do? To avoid blocking the main eventloop, updates datasetio/localfs to load data in a seperate thread Signed-off-by: Derek Higgins --- .../providers/inline/datasetio/localfs/datasetio.py | 8 ++++---- llama_stack/providers/utils/datasetio/url_utils.py | 10 +++++++--- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index cf4bf7fec..f489739bf 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -35,12 +35,12 @@ class PandasDataframeDataset: else: return self.df.iloc[idx].to_dict() - def load(self) -> None: + async def load(self) -> None: if self.df is not None: return if self.dataset_def.source.type == "uri": - self.df = get_dataframe_from_uri(self.dataset_def.source.uri) + self.df = await get_dataframe_from_uri(self.dataset_def.source.uri) elif self.dataset_def.source.type == "rows": self.df = pandas.DataFrame(self.dataset_def.source.rows) else: @@ -95,7 +95,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): ) -> IterrowsResponse: dataset_def = self.dataset_infos[dataset_id] dataset_impl = PandasDataframeDataset(dataset_def) - dataset_impl.load() + await dataset_impl.load() start_index = start_index or 0 @@ -114,7 +114,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: dataset_def = self.dataset_infos[dataset_id] dataset_impl = PandasDataframeDataset(dataset_def) - dataset_impl.load() + await dataset_impl.load() new_rows_df = pandas.DataFrame(rows) dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True) diff --git a/llama_stack/providers/utils/datasetio/url_utils.py b/llama_stack/providers/utils/datasetio/url_utils.py index 6a544ea49..386ee736d 100644 --- a/llama_stack/providers/utils/datasetio/url_utils.py +++ b/llama_stack/providers/utils/datasetio/url_utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio import base64 import io from urllib.parse import unquote @@ -13,12 +14,15 @@ import pandas from llama_stack.providers.utils.memory.vector_store import parse_data_url -def get_dataframe_from_uri(uri: str): +async def get_dataframe_from_uri(uri: str): df = None if uri.endswith(".csv"): - df = pandas.read_csv(uri) + # Moving to its own thread to avoid io from blocking the eventloop + # This isn't ideal as it moves more then just the IO to a new thread + # but it is as close as we can easly get + df = await asyncio.to_thread(pandas.read_csv, uri) elif uri.endswith(".xlsx"): - df = pandas.read_excel(uri) + df = await asyncio.to_thread(pandas.read_excel, uri) elif uri.startswith("data:"): parts = parse_data_url(uri) data = parts["data"] From ab777ef5cd919c73f77d9a7af8d3c5f03ab57098 Mon Sep 17 00:00:00 2001 From: Botao Chen Date: Wed, 19 Mar 2025 11:27:11 -0700 Subject: [PATCH 19/19] fix: fix open-benchmark template (#1695) ## What does this PR do? open-benchmark templated is broken after the datasets api refactor due to 2 reasons - provider_id and provider_resource_id are no longer needed - the type in run.yaml will be resolved as dict this PR is to fix the above 2 issues ## Test spin up a llama stack server successfully with llama stack run `llama_stack/templates/open-benchmark/run.yaml` --- llama_stack/apis/datasets/datasets.py | 2 -- llama_stack/distribution/routers/routing_tables.py | 8 ++++++++ llama_stack/templates/open-benchmark/open_benchmark.py | 5 ----- llama_stack/templates/open-benchmark/run.yaml | 5 ----- 4 files changed, 8 insertions(+), 12 deletions(-) diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 616371c7d..e2c940f64 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -121,8 +121,6 @@ class Dataset(CommonDatasetFields, Resource): class DatasetInput(CommonDatasetFields, BaseModel): dataset_id: str - provider_id: Optional[str] = None - provider_dataset_id: Optional[str] = None class ListDatasetsResponse(BaseModel): diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 5dea942f7..7aef2f8d5 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -20,6 +20,8 @@ from llama_stack.apis.datasets import ( DatasetType, DataSource, ListDatasetsResponse, + RowsDataSource, + URIDataSource, ) from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType from llama_stack.apis.resource import ResourceType @@ -377,6 +379,12 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets): metadata: Optional[Dict[str, Any]] = None, dataset_id: Optional[str] = None, ) -> Dataset: + if isinstance(source, dict): + if source["type"] == "uri": + source = URIDataSource.parse_obj(source) + elif source["type"] == "rows": + source = RowsDataSource.parse_obj(source) + if not dataset_id: dataset_id = f"dataset-{str(uuid.uuid4())}" diff --git a/llama_stack/templates/open-benchmark/open_benchmark.py b/llama_stack/templates/open-benchmark/open_benchmark.py index b339e8c80..acfbd78d6 100644 --- a/llama_stack/templates/open-benchmark/open_benchmark.py +++ b/llama_stack/templates/open-benchmark/open_benchmark.py @@ -170,7 +170,6 @@ def get_distribution_template() -> DistributionTemplate: default_datasets = [ DatasetInput( dataset_id="simpleqa", - provider_id="huggingface", purpose=DatasetPurpose.eval_messages_answer, source=URIDataSource( uri="huggingface://datasets/llamastack/simpleqa?split=train", @@ -178,7 +177,6 @@ def get_distribution_template() -> DistributionTemplate: ), DatasetInput( dataset_id="mmlu_cot", - provider_id="huggingface", purpose=DatasetPurpose.eval_messages_answer, source=URIDataSource( uri="huggingface://datasets/llamastack/mmlu_cot?split=test&name=all", @@ -186,7 +184,6 @@ def get_distribution_template() -> DistributionTemplate: ), DatasetInput( dataset_id="gpqa_cot", - provider_id="huggingface", purpose=DatasetPurpose.eval_messages_answer, source=URIDataSource( uri="huggingface://datasets/llamastack/gpqa_0shot_cot?split=test&name=gpqa_main", @@ -194,7 +191,6 @@ def get_distribution_template() -> DistributionTemplate: ), DatasetInput( dataset_id="math_500", - provider_id="huggingface", purpose=DatasetPurpose.eval_messages_answer, source=URIDataSource( uri="huggingface://datasets/llamastack/math_500?split=test", @@ -202,7 +198,6 @@ def get_distribution_template() -> DistributionTemplate: ), DatasetInput( dataset_id="bfcl", - provider_id="huggingface", purpose=DatasetPurpose.eval_messages_answer, source=URIDataSource( uri="huggingface://datasets/llamastack/bfcl_v3?split=train", diff --git a/llama_stack/templates/open-benchmark/run.yaml b/llama_stack/templates/open-benchmark/run.yaml index 93f437273..8dbf51472 100644 --- a/llama_stack/templates/open-benchmark/run.yaml +++ b/llama_stack/templates/open-benchmark/run.yaml @@ -164,35 +164,30 @@ datasets: uri: huggingface://datasets/llamastack/simpleqa?split=train metadata: {} dataset_id: simpleqa - provider_id: huggingface - purpose: eval/messages-answer source: type: uri uri: huggingface://datasets/llamastack/mmlu_cot?split=test&name=all metadata: {} dataset_id: mmlu_cot - provider_id: huggingface - purpose: eval/messages-answer source: type: uri uri: huggingface://datasets/llamastack/gpqa_0shot_cot?split=test&name=gpqa_main metadata: {} dataset_id: gpqa_cot - provider_id: huggingface - purpose: eval/messages-answer source: type: uri uri: huggingface://datasets/llamastack/math_500?split=test metadata: {} dataset_id: math_500 - provider_id: huggingface - purpose: eval/messages-answer source: type: uri uri: huggingface://datasets/llamastack/bfcl_v3?split=train metadata: {} dataset_id: bfcl - provider_id: huggingface scoring_fns: [] benchmarks: - dataset_id: simpleqa