From 832c535aafacab758a244a963e5a384f8b16c018 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Thu, 20 Feb 2025 18:59:48 -0600 Subject: [PATCH 01/45] feat(providers): add NVIDIA Inference embedding provider and tests (#935) # What does this PR do? add /v1/inference/embeddings implementation to NVIDIA provider **open topics** - - *asymmetric models*. NeMo Retriever includes asymmetric models, which are models that embed differently depending on if the input is destined for storage or lookup against storage. the /v1/inference/embeddings api does not allow the user to indicate the type of embedding to perform. see https://github.com/meta-llama/llama-stack/issues/934 - *truncation*. embedding models typically have a limited context window, e.g. 1024 tokens is common though newer models have 8k windows. when the input is larger than this window the endpoint cannot perform its designed function. two options: 0. return an error so the user can reduce the input size and retry; 1. perform truncation for the user and proceed (common strategies are left or right truncation). many users encounter context window size limits and will struggle to write reliable programs. this struggle is especially acute without access to the model's tokenizer. the /v1/inference/embeddings api does not allow the user to delegate truncation policy. see https://github.com/meta-llama/llama-stack/issues/933 - *dimensions*. "Matryoshka" embedding models are available. they allow users to control the number of embedding dimensions the model produces. this is a critical feature for managing storage constraints. embeddings of 1024 dimensions what achieve 95% recall for an application may not be worth the storage cost if a 512 dimensions can achieve 93% recall. controlling embedding dimensions allows applications to determine their recall and storage tradeoffs. the /v1/inference/embeddings api does not allow the user to control the output dimensions. see https://github.com/meta-llama/llama-stack/issues/932 ## Test Plan - `llama stack run llama_stack/templates/nvidia/run.yaml` - `LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/inference/test_embedding.py --embedding-model baai/bge-m3` ## Sources Please link relevant resources if necessary. ## Before submitting - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Ran pre-commit to handle lint / formatting issues. - [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [ ] Updated relevant documentation. - [x] Wrote necessary unit or integration tests. --------- Co-authored-by: Ashwin Bharambe --- .../remote_hosted_distro/nvidia.md | 1 + .../remote/inference/nvidia/models.py | 10 ++ .../remote/inference/nvidia/nvidia.py | 39 ++++++- llama_stack/templates/nvidia/nvidia.py | 4 +- llama_stack/templates/nvidia/run.yaml | 7 ++ tests/client-sdk/conftest.py | 12 ++ tests/client-sdk/inference/test_embedding.py | 103 ++++++++++++++++++ 7 files changed, 172 insertions(+), 4 deletions(-) create mode 100644 tests/client-sdk/inference/test_embedding.py diff --git a/docs/source/distributions/remote_hosted_distro/nvidia.md b/docs/source/distributions/remote_hosted_distro/nvidia.md index f352f737e..a1f70e450 100644 --- a/docs/source/distributions/remote_hosted_distro/nvidia.md +++ b/docs/source/distributions/remote_hosted_distro/nvidia.md @@ -36,6 +36,7 @@ The following models are available by default: - `meta-llama/Llama-3.2-3B-Instruct (meta/llama-3.2-3b-instruct)` - `meta-llama/Llama-3.2-11B-Vision-Instruct (meta/llama-3.2-11b-vision-instruct)` - `meta-llama/Llama-3.2-90B-Vision-Instruct (meta/llama-3.2-90b-vision-instruct)` +- `baai/bge-m3 (baai/bge-m3)` ### Prerequisite: API Keys diff --git a/llama_stack/providers/remote/inference/nvidia/models.py b/llama_stack/providers/remote/inference/nvidia/models.py index c432861ee..fa9944be1 100644 --- a/llama_stack/providers/remote/inference/nvidia/models.py +++ b/llama_stack/providers/remote/inference/nvidia/models.py @@ -4,8 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +from llama_stack.apis.models import ModelType from llama_stack.models.llama.datatypes import CoreModelId from llama_stack.providers.utils.inference.model_registry import ( + ProviderModelEntry, build_hf_repo_model_entry, ) @@ -46,6 +48,14 @@ _MODEL_ENTRIES = [ "meta/llama-3.2-90b-vision-instruct", CoreModelId.llama3_2_90b_vision_instruct.value, ), + ProviderModelEntry( + provider_model_id="baai/bge-m3", + model_type=ModelType.embedding, + metadata={ + "embedding_dimensions": 1024, + "context_length": 8192, + }, + ), # TODO(mf): how do we handle Nemotron models? # "Llama3.1-Nemotron-51B-Instruct" -> "meta/llama-3.1-nemotron-51b-instruct", ] diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 824389577..6f38230b2 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -10,6 +10,11 @@ from typing import AsyncIterator, List, Optional, Union from openai import APIConnectionError, AsyncOpenAI +from llama_stack.apis.common.content_types import ( + InterleavedContent, + InterleavedContentItem, + TextContentItem, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -19,7 +24,6 @@ from llama_stack.apis.inference import ( CompletionResponseStreamChunk, EmbeddingsResponse, Inference, - InterleavedContent, LogProbConfig, Message, ResponseFormat, @@ -117,9 +121,38 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): async def embeddings( self, model_id: str, - contents: List[InterleavedContent], + contents: List[str] | List[InterleavedContentItem], ) -> EmbeddingsResponse: - raise NotImplementedError() + if any(content_has_media(content) for content in contents): + raise NotImplementedError("Media is not supported") + + # + # Llama Stack: contents = List[str] | List[InterleavedContentItem] + # -> + # OpenAI: input = str | List[str] + # + # we can ignore str and always pass List[str] to OpenAI + # + flat_contents = [ + item.text if isinstance(item, TextContentItem) else item + for content in contents + for item in (content if isinstance(content, list) else [content]) + ] + input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents] + model = self.get_provider_model_id(model_id) + + response = await self._client.embeddings.create( + model=model, + input=input, + # extra_body={"input_type": "passage"|"query"}, # TODO(mf): how to tell caller's intent? + ) + + # + # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=List[float], ...)], ...) + # -> + # Llama Stack: EmbeddingsResponse(embeddings=List[List[float]]) + # + return EmbeddingsResponse(embeddings=[embedding.embedding for embedding in response.data]) async def chat_completion( self, diff --git a/llama_stack/templates/nvidia/nvidia.py b/llama_stack/templates/nvidia/nvidia.py index a505a1b93..56d13a09a 100644 --- a/llama_stack/templates/nvidia/nvidia.py +++ b/llama_stack/templates/nvidia/nvidia.py @@ -41,9 +41,11 @@ def get_distribution_template() -> DistributionTemplate: core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()} default_models = [ ModelInput( - model_id=core_model_to_hf_repo[m.llama_model], + model_id=core_model_to_hf_repo[m.llama_model] if m.llama_model else m.provider_model_id, provider_model_id=m.provider_model_id, provider_id="nvidia", + model_type=m.model_type, + metadata=m.metadata, ) for m in _MODEL_ENTRIES ] diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml index 14fb28354..891fd112a 100644 --- a/llama_stack/templates/nvidia/run.yaml +++ b/llama_stack/templates/nvidia/run.yaml @@ -135,6 +135,13 @@ models: provider_id: nvidia provider_model_id: meta/llama-3.2-90b-vision-instruct model_type: llm +- metadata: + embedding_dimensions: 1024 + context_length: 8192 + model_id: baai/bge-m3 + provider_id: nvidia + provider_model_id: baai/bge-m3 + model_type: embedding shields: [] vector_dbs: [] datasets: [] diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py index b397f7ab3..efdec6b01 100644 --- a/tests/client-sdk/conftest.py +++ b/tests/client-sdk/conftest.py @@ -58,6 +58,12 @@ def pytest_addoption(parser): default="meta-llama/Llama-Guard-3-1B", help="Specify the safety shield model to use for testing", ) + parser.addoption( + "--embedding-model", + action="store", + default=TEXT_MODEL, + help="Specify the embedding model to use for testing", + ) @pytest.fixture(scope="session") @@ -105,3 +111,9 @@ def pytest_generate_tests(metafunc): [metafunc.config.getoption("--vision-inference-model")], scope="session", ) + if "embedding_model_id" in metafunc.fixturenames: + metafunc.parametrize( + "embedding_model_id", + [metafunc.config.getoption("--embedding-model")], + scope="session", + ) diff --git a/tests/client-sdk/inference/test_embedding.py b/tests/client-sdk/inference/test_embedding.py new file mode 100644 index 000000000..a25382866 --- /dev/null +++ b/tests/client-sdk/inference/test_embedding.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +# +# Test plan: +# +# Types of input: +# - array of a string +# - array of a image (ImageContentItem, either URL or base64 string) +# - array of a text (TextContentItem) +# - array of array of texts, images, or both +# Types of output: +# - list of list of floats +# +# Todo: +# - negative tests +# - empty +# - empty list +# - empty string +# - empty text +# - empty image +# - list of empty texts +# - list of empty images +# - list of empty texts and images +# - long +# - long string +# - long text +# - large image +# - appropriate combinations +# - batch size +# - many inputs +# - invalid +# - invalid URL +# - invalid base64 +# - list of list of strings +# +# Notes: +# - use llama_stack_client fixture +# - use pytest.mark.parametrize when possible +# - no accuracy tests: only check the type of output, not the content +# + +import pytest +from llama_stack_client.types import EmbeddingsResponse +from llama_stack_client.types.shared.interleaved_content import ( + URL, + ImageContentItem, + ImageContentItemImage, + TextContentItem, +) + +DUMMY_STRING = "hello" +DUMMY_STRING2 = "world" +DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text") +DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text") +# TODO(mf): add a real image URL and base64 string +DUMMY_IMAGE_URL = ImageContentItem( + image=ImageContentItemImage(url=URL(uri="https://example.com/image.jpg")), type="image" +) +DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image") + + +@pytest.mark.parametrize( + "contents", + [ + [DUMMY_STRING, DUMMY_STRING2], + [DUMMY_TEXT, DUMMY_TEXT2], + ], + ids=[ + "list[string]", + "list[text]", + ], +) +def test_embedding_text(llama_stack_client, embedding_model_id, contents): + response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents) + assert isinstance(response, EmbeddingsResponse) + assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents) + assert isinstance(response.embeddings[0], list) + assert isinstance(response.embeddings[0][0], float) + + +@pytest.mark.parametrize( + "contents", + [ + [DUMMY_IMAGE_URL, DUMMY_IMAGE_BASE64], + [DUMMY_IMAGE_URL, DUMMY_STRING, DUMMY_IMAGE_BASE64, DUMMY_TEXT], + ], + ids=[ + "list[url,base64]", + "list[url,string,base64,text]", + ], +) +@pytest.mark.skip(reason="Media is not supported") +def test_embedding_image(llama_stack_client, embedding_model_id, contents): + response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents) + assert isinstance(response, EmbeddingsResponse) + assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents) + assert isinstance(response.embeddings[0], list) + assert isinstance(response.embeddings[0][0], float) From 35ae0e16a1d57ff864fed60e4d71cb5936d187a4 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 20 Feb 2025 17:50:24 -0800 Subject: [PATCH 02/45] Fix sqlite_vec config defaults --- .../providers/inline/vector_io/sqlite_vec/config.py | 12 +----------- llama_stack/templates/ollama/run.yaml | 5 +---- 2 files changed, 2 insertions(+), 15 deletions(-) diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/config.py b/llama_stack/providers/inline/vector_io/sqlite_vec/config.py index 5a830ff27..e5e3581c6 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/config.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/config.py @@ -4,26 +4,16 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -# config.py from typing import Any, Dict from pydantic import BaseModel -from llama_stack.providers.utils.kvstore.config import ( - KVStoreConfig, - SqliteKVStoreConfig, -) - class SQLiteVectorIOConfig(BaseModel): db_path: str - kvstore: KVStoreConfig @classmethod def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]: return { - "kvstore": SqliteKVStoreConfig.sample_run_config( - __distro_dir__=__distro_dir__, - db_name="sqlite_vec.db", - ) + "db_path": "${env.SQLITE_STORE_DIR:~/.llama/" + __distro_dir__ + "}/" + "sqlite_vec.db", } diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml index ab292c5e0..1f45fc228 100644 --- a/llama_stack/templates/ollama/run.yaml +++ b/llama_stack/templates/ollama/run.yaml @@ -30,10 +30,7 @@ providers: - provider_id: sqlite_vec provider_type: inline::sqlite_vec config: - kvstore: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db safety: - provider_id: llama-guard provider_type: inline::llama-guard From 35de4235562b83e3c2149c62c4d4b2e743352f94 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Fri, 21 Feb 2025 00:05:03 -0500 Subject: [PATCH 03/45] docs: Add missing uv command for docs generation in contributing guide (#1197) # What does this PR do? ``` make html /bin/sh: line 1: sphinx-build: command not found make: *** [Makefile:20: html] Error 127 ``` ## Test Plan Tested the command `uv run ./docs/openapi_generator/run_openapi_generator.sh` successfully. --- CONTRIBUTING.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 83a717273..c5952c8d2 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -145,7 +145,7 @@ If you modify or add new API endpoints, update the API documentation accordingly ```bash $ uv sync --extra dev -$ ./docs/openapi_generator/run_openapi_generator.sh +$ uv run ./docs/openapi_generator/run_openapi_generator.sh ``` The generated API documentation will be available in `docs/_static/`. Make sure to review the changes before committing. From 16e3d99942755b98288fa3b0801e4df1d5053b29 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Fri, 21 Feb 2025 00:05:47 -0500 Subject: [PATCH 04/45] docs: Simplify installation guide with `uv` (#1196) Given that we already switched to uv in other places. We should recommend uv in README's installation guide as well. It's a lot simpler. --- README.md | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index baec8c1bd..3946deea6 100644 --- a/README.md +++ b/README.md @@ -78,18 +78,14 @@ You have two ways to install this repository: ``` * **Install from source**: - If you prefer to install from the source code, make sure you have [conda installed](https://docs.conda.io/projects/conda/en/stable). + If you prefer to install from the source code, we recommend using [uv](https://github.com/astral-sh/uv). Then, run the following commands: ```bash - mkdir -p ~/local - cd ~/local git clone git@github.com:meta-llama/llama-stack.git - - conda create -n stack python=3.10 - conda activate stack - cd llama-stack - pip install -e . + + uv sync + uv pip install -e . ``` ### Documentation From 6820718b7168da2aae44570bbcfe1fb9dadd163a Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Fri, 21 Feb 2025 00:18:37 -0500 Subject: [PATCH 05/45] fix: BuiltinTool JSON serialization in remote vLLM provider (#1183) # What does this PR do? The `tool_name` attribute of `ToolDefinition` instances can either be a str or a BuiltinTool enum type. This fixes the remote vLLM provider to use the value of those BuiltinTool enums when serializing to JSON instead of attempting to serialize the actual enum to JSON. Reference of how this is handled in some other areas, since I followed that same pattern for the remote vLLM provider here: - [remote nvidia provider](https://github.com/meta-llama/llama-stack/blob/v0.1.3/llama_stack/providers/remote/inference/nvidia/openai_utils.py#L137-L140) - [meta reference provider](https://github.com/meta-llama/llama-stack/blob/v0.1.3/llama_stack/providers/inline/agents/meta_reference/agent_instance.py#L635-L636) There is opportunity to potentially reconcile the remove nvidia and remote vllm bits where they are both translating Llama Stack Inference APIs to OpenAI client requests, but that's a can of worms I didn't want to open for this bug fix. This explicitly fixes this error when using the remote vLLM provider and the agent tests: ``` TypeError: Object of type BuiltinTool is not JSON serializable ``` So, this is related to #1144 and addresses the immediate issue raised there. With this fix, `tests/client-sdk/agents/test_agents.py::test_builtin_tool_web_search` now gets past the JSON serialization error when using the remote vLLM provider and actually attempts to call the web search tool. I don't have any API keys setup for the actual web search providers yet, so I cannot verify everything works after that point. ## Test Plan I ran the `test_builtin_tool_web_search` locally with the remote vLLM provider like: ``` VLLM_URL="http://localhost:8000/v1" INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" LLAMA_STACK_CONFIG=remote-vllm python -m pytest -v tests/client-sdk/agents/test_agents.py::test_builtin_tool_web_search --inference-model "meta-llama/Llama-3.2-3B-Instruct" ``` Before my change, that reproduced the `TypeError: Object of type BuiltinTool is not JSON serializable` error. After my change, that error is gone and the test actually attempts the web search. That failed for me locally, due to lack of API key, but it gets past the JSON serialization error. Signed-off-by: Ben Browning --- llama_stack/providers/remote/inference/vllm/vllm.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index b37b0d305..d1793c524 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -34,6 +34,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model, ModelType +from llama_stack.models.llama.datatypes import BuiltinTool from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( @@ -112,10 +113,16 @@ def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict] if tool_param.required: compat_required.append(tool_key) + # The tool.tool_name can be a str or a BuiltinTool enum. If + # it's the latter, convert to a string. + tool_name = tool.tool_name + if isinstance(tool_name, BuiltinTool): + tool_name = tool_name.value + compat_tool = { "type": "function", "function": { - "name": tool.tool_name, + "name": tool_name, "description": tool.description, "parameters": { "type": "object", From dd43494847bcb48996de600cf22814d1bc0c63b7 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 20 Feb 2025 21:24:49 -0800 Subject: [PATCH 06/45] Fix inference test fixture --- llama_stack/providers/tests/inference/fixtures.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index b553b6b02..5291bffb3 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -20,7 +20,8 @@ from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig from llama_stack.providers.remote.inference.groq import GroqConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig -from llama_stack.providers.remote.inference.ollama import DEFAULT_OLLAMA_URL, OllamaImplConfig +from llama_stack.providers.remote.inference.ollama import OllamaImplConfig +from llama_stack.providers.remote.inference.ollama.config import DEFAULT_OLLAMA_URL from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig from llama_stack.providers.remote.inference.tgi import TGIImplConfig from llama_stack.providers.remote.inference.together import TogetherImplConfig From 33a64eb5eca5ca17653ed7d455f8aeb335b3f0ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Fri, 21 Feb 2025 06:37:37 +0100 Subject: [PATCH 07/45] ci: improve GitHub Actions workflow for website builds (#1151) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? Refine the existing update-readthedocs.yml workflow to enhance automation and reliability. Updates include: - Expanding path triggers to cover all documentation files (docs/**) and build artifacts. - Adding steps to set up Python (3.11), install uv, sync dependencies, and build HTML using make html. - Ensuring the ReadTheDocs build trigger only runs on workflow_dispatch events. These improvements help validate website builds in PRs, preventing issues before merging. Signed-off-by: Sébastien Han Signed-off-by: Sébastien Han --- .github/workflows/update-readthedocs.yml | 31 +++++++++++++++++++++--- 1 file changed, 28 insertions(+), 3 deletions(-) diff --git a/.github/workflows/update-readthedocs.yml b/.github/workflows/update-readthedocs.yml index 70369bcc4..23bafa1e5 100644 --- a/.github/workflows/update-readthedocs.yml +++ b/.github/workflows/update-readthedocs.yml @@ -11,17 +11,42 @@ on: branches: - main paths: - - 'docs/source/**' - - 'docs/resources/**' + - 'docs/**' + - '.github/workflows/update-readthedocs.yml' + pull_request: + branches: + - main + paths: + - 'docs/**' - '.github/workflows/update-readthedocs.yml' jobs: update-readthedocs: - runs-on: ubuntu-latest + runs-on: ubuntu-latest env: TOKEN: ${{ secrets.READTHEDOCS_TOKEN }} steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: '3.11' + + - name: Install the latest version of uv + uses: astral-sh/setup-uv@v5 + + - name: Sync with uv + run: uv sync --extra docs + + - name: Build HTML + run: | + cd docs + uv run make html + - name: Trigger ReadTheDocs build + if: github.event_name != 'pull_request' run: | if [ -z "$TOKEN" ]; then echo "READTHEDOCS_TOKEN is not set" From cfa752fc922cdf479699f7c69c66ba778eeec963 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Thu, 20 Feb 2025 21:38:35 -0800 Subject: [PATCH 08/45] fix: pass tool_prompt_format to chat_formatter (#1198) Summary: Need this to format the completion message with tool_calls correctly. See added unittest. Test Plan: python -m unittest llama_stack.providers.tests.inference.test_prompt_adapter --- .../tests/inference/test_prompt_adapter.py | 44 +++++++++++++++++++ .../utils/inference/prompt_adapter.py | 8 +++- 2 files changed, 50 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/tests/inference/test_prompt_adapter.py b/llama_stack/providers/tests/inference/test_prompt_adapter.py index 323c6cb6a..2a6dbb561 100644 --- a/llama_stack/providers/tests/inference/test_prompt_adapter.py +++ b/llama_stack/providers/tests/inference/test_prompt_adapter.py @@ -8,7 +8,10 @@ import unittest from llama_stack.apis.inference import ( ChatCompletionRequest, + CompletionMessage, + StopReason, SystemMessage, + ToolCall, ToolConfig, UserMessage, ) @@ -20,6 +23,7 @@ from llama_stack.models.llama.datatypes import ( ) from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_messages, + chat_completion_request_to_prompt, ) MODEL = "Llama3.1-8B-Instruct" @@ -119,6 +123,46 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): self.assertTrue("Return function calls in JSON format" in messages[1].content) self.assertEqual(messages[-1].content, content) + async def test_completion_message_encoding(self): + request = ChatCompletionRequest( + model=MODEL3_2, + messages=[ + UserMessage(content="hello"), + CompletionMessage( + content="", + stop_reason=StopReason.end_of_turn, + tool_calls=[ + ToolCall( + tool_name="custom1", + arguments={"param1": "value1"}, + call_id="123", + ) + ], + ), + ], + tools=[ + ToolDefinition( + tool_name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, + ), + ], + tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list), + ) + prompt = await chat_completion_request_to_prompt(request, request.model) + self.assertIn('[custom1(param1="value1")]', prompt) + + request.model = MODEL + request.tool_config.tool_prompt_format = ToolPromptFormat.json + prompt = await chat_completion_request_to_prompt(request, request.model) + self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt) + async def test_user_provided_system_message(self): content = "Hello !" system_prompt = "You are a pirate" diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 10fe442e8..ca6fe04fd 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -252,7 +252,9 @@ async def chat_completion_request_to_prompt(request: ChatCompletionRequest, llam request = await convert_request_to_raw(request) formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) - model_input = formatter.encode_dialog_prompt(request.messages) + model_input = formatter.encode_dialog_prompt( + request.messages, tool_prompt_format=request.tool_config.tool_prompt_format + ) return formatter.tokenizer.decode(model_input.tokens) @@ -264,7 +266,9 @@ async def chat_completion_request_to_model_input_info( request = await convert_request_to_raw(request) formatter = ChatFormat(tokenizer=Tokenizer.get_instance()) - model_input = formatter.encode_dialog_prompt(request.messages) + model_input = formatter.encode_dialog_prompt( + request.messages, tool_prompt_format=request.tool_config.tool_prompt_format + ) return ( formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens), From 6f9d6223401a9eedfdfe7bc925dc1fb51a495c73 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 20 Feb 2025 21:43:13 -0800 Subject: [PATCH 09/45] fix(api): update embeddings signature so inputs and outputs list align (#1161) See Issue #922 The change is slightly backwards incompatible but no callsite (in our client codebases or stack-apps) every passes a depth-2 `List[List[InterleavedContentItem]]` (which is now disallowed.) ## Test Plan ```bash $ cd llama_stack/providers/tests/inference $ pytest -s -v -k fireworks test_embeddings.py \ --inference-model nomic-ai/nomic-embed-text-v1.5 --env EMBEDDING_DIMENSION=784 $ pytest -s -v -k together test_embeddings.py \ --inference-model togethercomputer/m2-bert-80M-8k-retrieval --env EMBEDDING_DIMENSION=784 $ pytest -s -v -k ollama test_embeddings.py \ --inference-model all-minilm:latest --env EMBEDDING_DIMENSION=784 ``` Also ran `tests/client-sdk/inference/test_embeddings.py` --- docs/_static/llama-stack-spec.html | 20 ++++++++++---- docs/_static/llama-stack-spec.yaml | 16 +++++++----- llama_stack/apis/inference/inference.py | 6 ++--- llama_stack/distribution/routers/routers.py | 4 +-- .../providers/inline/inference/vllm/vllm.py | 3 ++- .../remote/inference/bedrock/bedrock.py | 4 +-- .../remote/inference/cerebras/cerebras.py | 4 +-- .../remote/inference/databricks/databricks.py | 6 ++--- .../remote/inference/fireworks/fireworks.py | 4 +-- .../providers/remote/inference/groq/groq.py | 3 ++- .../remote/inference/ollama/ollama.py | 3 ++- .../remote/inference/runpod/runpod.py | 5 ++-- .../remote/inference/sambanova/sambanova.py | 26 ++++++++++++++++--- .../providers/remote/inference/tgi/tgi.py | 4 +-- .../remote/inference/together/together.py | 4 +-- .../providers/remote/inference/vllm/vllm.py | 10 +++++-- .../utils/inference/embedding_mixin.py | 4 +-- 17 files changed, 85 insertions(+), 41 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 40c167685..638f7bb7b 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -4929,11 +4929,21 @@ "description": "The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint." }, "contents": { - "type": "array", - "items": { - "$ref": "#/components/schemas/InterleavedContent" - }, - "description": "List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text." + "oneOf": [ + { + "type": "array", + "items": { + "type": "string" + } + }, + { + "type": "array", + "items": { + "$ref": "#/components/schemas/InterleavedContentItem" + } + } + ], + "description": "List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text." } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index c5043665b..08effe7cf 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -3224,13 +3224,17 @@ components: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint. contents: - type: array - items: - $ref: '#/components/schemas/InterleavedContent' + oneOf: + - type: array + items: + type: string + - type: array + items: + $ref: '#/components/schemas/InterleavedContentItem' description: >- - List of contents to generate embeddings for. Note that content can be - multimodal. The behavior depends on the model and provider. Some models - may only support text. + List of contents to generate embeddings for. Each content can be a string + or an InterleavedContentItem (and hence can be multimodal). The behavior + depends on the model and provider. Some models may only support text. additionalProperties: false required: - model_id diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index a3fb69477..2dfe55977 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -20,7 +20,7 @@ from typing import ( from pydantic import BaseModel, Field, field_validator from typing_extensions import Annotated -from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent +from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem from llama_stack.apis.models import Model from llama_stack.apis.telemetry.telemetry import MetricResponseMixin from llama_stack.models.llama.datatypes import ( @@ -481,12 +481,12 @@ class Inference(Protocol): async def embeddings( self, model_id: str, - contents: List[InterleavedContent], + contents: List[str] | List[InterleavedContentItem], ) -> EmbeddingsResponse: """Generate embeddings for content pieces using the specified model. :param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint. - :param contents: List of contents to generate embeddings for. Note that content can be multimodal. The behavior depends on the model and provider. Some models may only support text. + :param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text. :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id} """ ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 016ca4984..d885ebc09 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional -from llama_stack.apis.common.content_types import URL, InterleavedContent +from llama_stack.apis.common.content_types import URL, InterleavedContent, InterleavedContentItem from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.eval import ( BenchmarkConfig, @@ -214,7 +214,7 @@ class InferenceRouter(Inference): async def embeddings( self, model_id: str, - contents: List[InterleavedContent], + contents: List[str] | List[InterleavedContentItem], ) -> EmbeddingsResponse: model = await self.routing_table.get_model(model_id) if model is None: diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index 5b0df91e7..a6f7a78af 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -23,6 +23,7 @@ from llama_stack.apis.inference import ( CompletionResponseStreamChunk, EmbeddingsResponse, Inference, + InterleavedContentItem, LogProbConfig, Message, ResponseFormat, @@ -230,5 +231,5 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): async for chunk in process_chat_completion_stream_response(stream, request): yield chunk - async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse: + async def embeddings(self, model_id: str, contents: List[str] | List[InterleavedContentItem]) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index 9c5a291db..69fb5dea2 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -9,7 +9,7 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from botocore.client import BaseClient -from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -162,7 +162,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): async def embeddings( self, model_id: str, - contents: List[InterleavedContent], + contents: List[str] | List[InterleavedContentItem], ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) embeddings = [] diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 0a27d81d7..71b9155d5 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union from cerebras.cloud.sdk import AsyncCerebras -from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem from llama_stack.apis.inference import ( ChatCompletionRequest, CompletionRequest, @@ -172,6 +172,6 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference): async def embeddings( self, model_id: str, - contents: List[InterleavedContent], + contents: List[str] | List[InterleavedContentItem], ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index de13638f5..e3acd4314 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional from openai import OpenAI -from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -130,7 +130,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): async def embeddings( self, - model: str, - contents: List[InterleavedContent], + model_id: str, + contents: List[str] | List[InterleavedContentItem], ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 3f455da3c..95fe65c39 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union from fireworks.client import Fireworks -from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -232,7 +232,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv async def embeddings( self, model_id: str, - contents: List[InterleavedContent], + contents: List[str] | List[InterleavedContentItem], ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py index c75e92dfe..4b21ae81d 100644 --- a/llama_stack/providers/remote/inference/groq/groq.py +++ b/llama_stack/providers/remote/inference/groq/groq.py @@ -19,6 +19,7 @@ from llama_stack.apis.inference import ( EmbeddingsResponse, Inference, InterleavedContent, + InterleavedContentItem, LogProbConfig, Message, ResponseFormat, @@ -140,7 +141,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD async def embeddings( self, model_id: str, - contents: List[InterleavedContent], + contents: List[str] | List[InterleavedContentItem], ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 1dbcbc294..0071aaa5d 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -13,6 +13,7 @@ from ollama import AsyncClient from llama_stack.apis.common.content_types import ( ImageContentItem, InterleavedContent, + InterleavedContentItem, TextContentItem, ) from llama_stack.apis.inference import ( @@ -258,7 +259,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): async def embeddings( self, model_id: str, - contents: List[InterleavedContent], + contents: List[str] | List[InterleavedContentItem], ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index 09122a8e6..a5acc47f8 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -69,9 +69,10 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference): response_format: Optional[ResponseFormat] = None, tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = ToolChoice.auto, - tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, + tool_prompt_format: Optional[ToolPromptFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> AsyncGenerator: request = ChatCompletionRequest( model=model, @@ -119,6 +120,6 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference): async def embeddings( self, model: str, - contents: List[InterleavedContent], + contents: List[str] | List[InterleavedContentItem], ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index c05284d7d..b60954abc 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -5,16 +5,36 @@ # the root directory of this source tree. import json -from typing import AsyncGenerator +from typing import AsyncGenerator, List, Optional from openai import OpenAI from llama_stack.apis.common.content_types import ( ImageContentItem, InterleavedContent, + InterleavedContentItem, TextContentItem, ) -from llama_stack.apis.inference import * # noqa: F403 +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + CompletionMessage, + EmbeddingsResponse, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + StopReason, + SystemMessage, + ToolCall, + ToolChoice, + ToolConfig, + ToolDefinition, + ToolPromptFormat, + ToolResponseMessage, + UserMessage, +) from llama_stack.models.llama.datatypes import ( GreedySamplingStrategy, TopKSamplingStrategy, @@ -119,7 +139,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): async def embeddings( self, model_id: str, - contents: List[InterleavedContent], + contents: List[str] | List[InterleavedContentItem], ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 1a50e3b61..a52abd20d 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -10,7 +10,7 @@ from typing import AsyncGenerator, List, Optional from huggingface_hub import AsyncInferenceClient, HfApi -from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -268,7 +268,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): async def embeddings( self, model_id: str, - contents: List[InterleavedContent], + contents: List[str] | List[InterleavedContentItem], ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 8afd3e85b..a2c4f1542 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -8,7 +8,7 @@ from typing import AsyncGenerator, List, Optional, Union from together import Together -from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -219,7 +219,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi async def embeddings( self, model_id: str, - contents: List[InterleavedContent], + contents: List[str] | List[InterleavedContentItem], ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) assert all(not content_has_media(content) for content in contents), ( diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index d1793c524..bff5da8a7 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -10,7 +10,13 @@ from typing import AsyncGenerator, List, Optional, Union from llama_models.datatypes import StopReason, ToolCall from openai import OpenAI -from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus +from llama_stack.apis.common.content_types import ( + InterleavedContent, + InterleavedContentItem, + TextDelta, + ToolCallDelta, + ToolCallParseStatus, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -376,7 +382,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): async def embeddings( self, model_id: str, - contents: List[InterleavedContent], + contents: List[str] | List[InterleavedContentItem], ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index a84c2eecb..947a62f09 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -9,7 +9,7 @@ from typing import List from llama_stack.apis.inference import ( EmbeddingsResponse, - InterleavedContent, + InterleavedContentItem, ModelStore, ) @@ -25,7 +25,7 @@ class SentenceTransformerEmbeddingMixin: async def embeddings( self, model_id: str, - contents: List[InterleavedContent], + contents: List[str] | List[InterleavedContentItem], ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) embedding_model = self._load_sentence_transformer_model(model.provider_resource_id) From 81ce39a607b8f1a286a91e3ee2c6df5079433eb0 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 20 Feb 2025 22:27:12 -0800 Subject: [PATCH 10/45] feat(api): Add options for supporting various embedding models (#1192) We need to support: - asymmetric embedding models (#934) - truncation policies (#933) - varying dimensional output (#932) ## Test Plan ```bash $ cd llama_stack/providers/tests/inference $ pytest -s -v -k fireworks test_embeddings.py \ --inference-model nomic-ai/nomic-embed-text-v1.5 --env EMBEDDING_DIMENSION=784 $ pytest -s -v -k together test_embeddings.py \ --inference-model togethercomputer/m2-bert-80M-8k-retrieval --env EMBEDDING_DIMENSION=784 $ pytest -s -v -k ollama test_embeddings.py \ --inference-model all-minilm:latest --env EMBEDDING_DIMENSION=784 ``` --- docs/_static/llama-stack-spec.html | 21 +++++++++++++ docs/_static/llama-stack-spec.yaml | 22 ++++++++++++++ llama_stack/apis/inference/inference.py | 30 +++++++++++++++++++ llama_stack/distribution/routers/routers.py | 14 ++++++++- .../providers/inline/inference/vllm/vllm.py | 11 ++++++- .../remote/inference/bedrock/bedrock.py | 10 ++++++- .../remote/inference/cerebras/cerebras.py | 10 ++++++- .../remote/inference/databricks/databricks.py | 10 ++++++- .../remote/inference/fireworks/fireworks.py | 10 ++++++- .../providers/remote/inference/groq/groq.py | 11 ++++++- .../remote/inference/nvidia/nvidia.py | 11 ++++++- .../remote/inference/ollama/ollama.py | 5 ++++ .../inference/passthrough/passthrough.py | 8 +++++ .../remote/inference/runpod/runpod.py | 3 ++ .../remote/inference/sambanova/sambanova.py | 5 ++++ .../providers/remote/inference/tgi/tgi.py | 10 ++++++- .../remote/inference/together/together.py | 10 ++++++- .../providers/remote/inference/vllm/vllm.py | 5 ++++ .../utils/inference/embedding_mixin.py | 7 ++++- 19 files changed, 202 insertions(+), 11 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 638f7bb7b..fab7c802e 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -4944,6 +4944,27 @@ } ], "description": "List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text." + }, + "text_truncation": { + "type": "string", + "enum": [ + "none", + "start", + "end" + ], + "description": "(Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length." + }, + "output_dimension": { + "type": "integer", + "description": "(Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models." + }, + "task_type": { + "type": "string", + "enum": [ + "query", + "document" + ], + "description": "(Optional) How is the embedding being used? This is only supported by asymmetric embedding models." } }, "additionalProperties": false, diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 08effe7cf..fc57bf258 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -3235,6 +3235,28 @@ components: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text. + text_truncation: + type: string + enum: + - none + - start + - end + description: >- + (Optional) Config for how to truncate text for embedding when text is + longer than the model's max sequence length. + output_dimension: + type: integer + description: >- + (Optional) Output dimensionality for the embeddings. Only supported by + Matryoshka models. + task_type: + type: string + enum: + - query + - document + description: >- + (Optional) How is the embedding being used? This is only supported by + asymmetric embedding models. additionalProperties: false required: - model_id diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 2dfe55977..d83506dd4 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -402,6 +402,30 @@ class ModelStore(Protocol): def get_model(self, identifier: str) -> Model: ... +class TextTruncation(Enum): + """Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left. + + :cvar none: No truncation (default). If the text is longer than the model's max sequence length, you will get an error. + :cvar start: Truncate from the start + :cvar end: Truncate from the end + """ + + none = "none" + start = "start" + end = "end" + + +class EmbeddingTaskType(Enum): + """How is the embedding being used? This is only supported by asymmetric embedding models. + + :cvar query: Used for a query for semantic search. + :cvar document: Used at indexing time when ingesting documents. + """ + + query = "query" + document = "document" + + @runtime_checkable @trace_protocol class Inference(Protocol): @@ -482,11 +506,17 @@ class Inference(Protocol): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: """Generate embeddings for content pieces using the specified model. :param model_id: The identifier of the model to use. The model must be an embedding model registered with Llama Stack and available via the /models endpoint. :param contents: List of contents to generate embeddings for. Each content can be a string or an InterleavedContentItem (and hence can be multimodal). The behavior depends on the model and provider. Some models may only support text. + :param output_dimension: (Optional) Output dimensionality for the embeddings. Only supported by Matryoshka models. + :param text_truncation: (Optional) Config for how to truncate text for embedding when text is longer than the model's max sequence length. + :param task_type: (Optional) How is the embedding being used? This is only supported by asymmetric embedding models. :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id} """ ... diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index d885ebc09..df4ed03d3 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -6,7 +6,11 @@ from typing import Any, AsyncGenerator, Dict, List, Optional -from llama_stack.apis.common.content_types import URL, InterleavedContent, InterleavedContentItem +from llama_stack.apis.common.content_types import ( + URL, + InterleavedContent, + InterleavedContentItem, +) from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.eval import ( BenchmarkConfig, @@ -17,11 +21,13 @@ from llama_stack.apis.eval import ( ) from llama_stack.apis.inference import ( EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -215,6 +221,9 @@ class InferenceRouter(Inference): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: model = await self.routing_table.get_model(model_id) if model is None: @@ -224,6 +233,9 @@ class InferenceRouter(Inference): return await self.routing_table.get_provider_impl(model_id).embeddings( model_id=model_id, contents=contents, + text_truncation=text_truncation, + output_dimension=output_dimension, + task_type=task_type, ) diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index a6f7a78af..d03ea933a 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -22,12 +22,14 @@ from llama_stack.apis.inference import ( CompletionResponse, CompletionResponseStreamChunk, EmbeddingsResponse, + EmbeddingTaskType, Inference, InterleavedContentItem, LogProbConfig, Message, ResponseFormat, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -231,5 +233,12 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): async for chunk in process_chat_completion_stream_response(stream, request): yield chunk - async def embeddings(self, model_id: str, contents: List[str] | List[InterleavedContentItem]) -> EmbeddingsResponse: + async def embeddings( + self, + model_id: str, + contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, + ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index 69fb5dea2..b82a4c752 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -9,17 +9,22 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union from botocore.client import BaseClient -from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem +from llama_stack.apis.common.content_types import ( + InterleavedContent, + InterleavedContentItem, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseStreamChunk, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -163,6 +168,9 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) embeddings = [] diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 71b9155d5..4deeea630 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -8,17 +8,22 @@ from typing import AsyncGenerator, List, Optional, Union from cerebras.cloud.sdk import AsyncCerebras -from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem +from llama_stack.apis.common.content_types import ( + InterleavedContent, + InterleavedContentItem, +) from llama_stack.apis.inference import ( ChatCompletionRequest, CompletionRequest, CompletionResponse, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -173,5 +178,8 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index e3acd4314..75751c8b1 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -8,16 +8,21 @@ from typing import AsyncGenerator, List, Optional from openai import OpenAI -from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem +from llama_stack.apis.common.content_types import ( + InterleavedContent, + InterleavedContentItem, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, SamplingParams, + TextTruncation, ToolChoice, ToolDefinition, ToolPromptFormat, @@ -132,5 +137,8 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index 95fe65c39..b9b23584b 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -8,19 +8,24 @@ from typing import AsyncGenerator, List, Optional, Union from fireworks.client import Fireworks -from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem +from llama_stack.apis.common.content_types import ( + InterleavedContent, + InterleavedContentItem, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, CompletionRequest, CompletionResponse, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, ResponseFormatType, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -233,6 +238,9 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py index 4b21ae81d..45c15a467 100644 --- a/llama_stack/providers/remote/inference/groq/groq.py +++ b/llama_stack/providers/remote/inference/groq/groq.py @@ -17,17 +17,23 @@ from llama_stack.apis.inference import ( CompletionResponse, CompletionResponseStreamChunk, EmbeddingsResponse, + EmbeddingTaskType, Inference, InterleavedContent, InterleavedContentItem, LogProbConfig, Message, ResponseFormat, + TextTruncation, ToolChoice, ToolConfig, ) from llama_stack.distribution.request_headers import NeedsRequestProviderData -from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat +from llama_stack.models.llama.datatypes import ( + SamplingParams, + ToolDefinition, + ToolPromptFormat, +) from llama_stack.models.llama.sku_list import CoreModelId from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.utils.inference.model_registry import ( @@ -142,6 +148,9 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 6f38230b2..ecd53e91c 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -23,14 +23,20 @@ from llama_stack.apis.inference import ( CompletionResponse, CompletionResponseStreamChunk, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, + TextTruncation, ToolChoice, ToolConfig, ) -from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat +from llama_stack.models.llama.datatypes import ( + SamplingParams, + ToolDefinition, + ToolPromptFormat, +) from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) @@ -122,6 +128,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: if any(content_has_media(content) for content in contents): raise NotImplementedError("Media is not supported") diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 0071aaa5d..62c8381a8 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -21,11 +21,13 @@ from llama_stack.apis.inference import ( ChatCompletionResponse, CompletionRequest, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -260,6 +262,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) diff --git a/llama_stack/providers/remote/inference/passthrough/passthrough.py b/llama_stack/providers/remote/inference/passthrough/passthrough.py index a34c34f69..11da6bb9e 100644 --- a/llama_stack/providers/remote/inference/passthrough/passthrough.py +++ b/llama_stack/providers/remote/inference/passthrough/passthrough.py @@ -11,11 +11,13 @@ from llama_stack_client import LlamaStackClient from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.inference import ( EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -138,6 +140,9 @@ class PassthroughInferenceAdapter(Inference): self, model_id: str, contents: List[InterleavedContent], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: client = self._get_client() model = await self.model_store.get_model(model_id) @@ -145,4 +150,7 @@ class PassthroughInferenceAdapter(Inference): return client.inference.embeddings( model_id=model.provider_resource_id, contents=contents, + text_truncation=text_truncation, + output_dimension=output_dimension, + task_type=task_type, ) diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index a5acc47f8..bd620aa64 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -121,5 +121,8 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference): self, model: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index b60954abc..57a296258 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -20,6 +20,7 @@ from llama_stack.apis.inference import ( ChatCompletionResponse, CompletionMessage, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, @@ -27,6 +28,7 @@ from llama_stack.apis.inference import ( SamplingParams, StopReason, SystemMessage, + TextTruncation, ToolCall, ToolChoice, ToolConfig, @@ -140,6 +142,9 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index a52abd20d..d09ca241f 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -10,18 +10,23 @@ from typing import AsyncGenerator, List, Optional from huggingface_hub import AsyncInferenceClient, HfApi -from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem +from llama_stack.apis.common.content_types import ( + InterleavedContent, + InterleavedContentItem, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, CompletionRequest, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, ResponseFormatType, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -269,6 +274,9 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: raise NotImplementedError() diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index a2c4f1542..1fca54bb3 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -8,18 +8,23 @@ from typing import AsyncGenerator, List, Optional, Union from together import Together -from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem +from llama_stack.apis.common.content_types import ( + InterleavedContent, + InterleavedContentItem, +) from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, CompletionRequest, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, ResponseFormatType, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -220,6 +225,9 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) assert all(not content_has_media(content) for content in contents), ( diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index bff5da8a7..b9422d85d 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -28,12 +28,14 @@ from llama_stack.apis.inference import ( CompletionResponse, CompletionResponseStreamChunk, EmbeddingsResponse, + EmbeddingTaskType, Inference, LogProbConfig, Message, ResponseFormat, ResponseFormatType, SamplingParams, + TextTruncation, ToolChoice, ToolConfig, ToolDefinition, @@ -383,6 +385,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index 947a62f09..32aa5da3f 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -5,12 +5,14 @@ # the root directory of this source tree. import logging -from typing import List +from typing import List, Optional from llama_stack.apis.inference import ( EmbeddingsResponse, + EmbeddingTaskType, InterleavedContentItem, ModelStore, + TextTruncation, ) EMBEDDING_MODELS = {} @@ -26,6 +28,9 @@ class SentenceTransformerEmbeddingMixin: self, model_id: str, contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) embedding_model = self._load_sentence_transformer_model(model.provider_resource_id) From 36b762303cd124febe0a6a43389891767c8072e5 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 20 Feb 2025 22:46:17 -0800 Subject: [PATCH 11/45] Fix client-sdk inference text -- spurious parameterization of test_case --- tests/client-sdk/inference/test_text_inference.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/tests/client-sdk/inference/test_text_inference.py b/tests/client-sdk/inference/test_text_inference.py index 1fe53ab86..ac7481507 100644 --- a/tests/client-sdk/inference/test_text_inference.py +++ b/tests/client-sdk/inference/test_text_inference.py @@ -250,9 +250,12 @@ def test_text_chat_completion_with_tool_calling_and_streaming( assert tool_invocation_content == "[get_weather, {'location': 'San Francisco, CA'}]" -@pytest.mark.parametrize("test_case", ["chat_completion-01"]) def test_text_chat_completion_with_tool_choice_required( - llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format, inference_provider_type + llama_stack_client, + text_model_id, + get_weather_tool_definition, + provider_tool_format, + inference_provider_type, ): response = llama_stack_client.inference.chat_completion( model_id=text_model_id, @@ -261,7 +264,10 @@ def test_text_chat_completion_with_tool_choice_required( {"role": "user", "content": "What's the weather like in San Francisco?"}, ], tools=[get_weather_tool_definition], - tool_config={"tool_choice": "required", "tool_prompt_format": provider_tool_format}, + tool_config={ + "tool_choice": "required", + "tool_prompt_format": provider_tool_format, + }, stream=True, ) tool_invocation_content = extract_tool_invocation_content(response) From 34226d6c935a4cf68ba62a6973c8aea2a6c89d16 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 20 Feb 2025 23:10:33 -0800 Subject: [PATCH 12/45] Another test_case related breakage fix --- tests/client-sdk/inference/test_text_inference.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/client-sdk/inference/test_text_inference.py b/tests/client-sdk/inference/test_text_inference.py index ac7481507..545325bbe 100644 --- a/tests/client-sdk/inference/test_text_inference.py +++ b/tests/client-sdk/inference/test_text_inference.py @@ -291,6 +291,7 @@ def test_text_chat_completion_with_tool_choice_none( assert tool_invocation_content == "" +@pytest.mark.parametrize("test_case", ["chat_completion-01"]) def test_text_chat_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type, test_case): class AnswerFormat(BaseModel): first_name: str From 3099c5243fb4d93cc9df4282395371ec97660812 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Fri, 21 Feb 2025 10:02:21 -0600 Subject: [PATCH 13/45] fix: update URL import, URL -> ImageContentItemImageURL (#1204) # What does this PR do? fixes test to use new name for URL import ## Test Plan `LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/inference/test_embedding.py --embedding-model baai/bge-m3` --- tests/client-sdk/inference/test_embedding.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/client-sdk/inference/test_embedding.py b/tests/client-sdk/inference/test_embedding.py index a25382866..602f9c062 100644 --- a/tests/client-sdk/inference/test_embedding.py +++ b/tests/client-sdk/inference/test_embedding.py @@ -47,9 +47,9 @@ import pytest from llama_stack_client.types import EmbeddingsResponse from llama_stack_client.types.shared.interleaved_content import ( - URL, ImageContentItem, ImageContentItemImage, + ImageContentItemImageURL, TextContentItem, ) @@ -59,7 +59,7 @@ DUMMY_TEXT = TextContentItem(text=DUMMY_STRING, type="text") DUMMY_TEXT2 = TextContentItem(text=DUMMY_STRING2, type="text") # TODO(mf): add a real image URL and base64 string DUMMY_IMAGE_URL = ImageContentItem( - image=ImageContentItemImage(url=URL(uri="https://example.com/image.jpg")), type="image" + image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image" ) DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image") From c9c4a3c92129da56f1511f3f3575612705928e2c Mon Sep 17 00:00:00 2001 From: Reid <61492567+reidliu41@users.noreply.github.com> Date: Sat, 22 Feb 2025 00:05:12 +0800 Subject: [PATCH 14/45] feat: model remove cmd (#1128) # What does this PR do? [Provide a short summary of what this PR does and why. Link to relevant issues if applicable.] add a subcommand, help to clean the unneeded model: ``` $ llama model --help usage: llama model [-h] {download,list,prompt-format,describe,verify-download,remove} ... Work with llama models options: -h, --help show this help message and exit $ llama model remove --help usage: llama model remove [-h] -m MODEL [-f] Remove the downloaded llama model options: -h, --help show this help message and exit -m MODEL, --model MODEL Specify the llama downloaded model name -f, --force Used to forcefully remove the llama model from the storage without further confirmation $ llama model remove -m Llama3.2-1B-Instruct:int4-qlora-eo8 Are you sure you want to remove Llama3.2-1B-Instruct:int4-qlora-eo8? (y/n): n Removal aborted. $ llama model remove -mLlama3.2-1B-Instruct:int4-qlora-eo8-f Llama3.2-1B-Instruct:int4-qlora-eo8 removed. ``` [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation) --------- Signed-off-by: reidliu Co-authored-by: reidliu --- .../references/llama_cli_reference/index.md | 15 +++-- llama_stack/cli/model/model.py | 2 + llama_stack/cli/model/remove.py | 67 +++++++++++++++++++ 3 files changed, 80 insertions(+), 4 deletions(-) create mode 100644 llama_stack/cli/model/remove.py diff --git a/docs/source/references/llama_cli_reference/index.md b/docs/source/references/llama_cli_reference/index.md index 76abce544..a43666963 100644 --- a/docs/source/references/llama_cli_reference/index.md +++ b/docs/source/references/llama_cli_reference/index.md @@ -171,7 +171,7 @@ The `llama model` command helps you explore the model’s interface. llama model --help ``` ``` -usage: llama model [-h] {download,list,prompt-format,describe} ... +usage: llama model [-h] {download,list,prompt-format,describe,verify-download,remove} ... Work with llama models @@ -179,15 +179,15 @@ options: -h, --help show this help message and exit model_subcommands: - {download,list,prompt-format,describe} + {download,list,prompt-format,describe,verify-download,remove} ``` +### Describe + You can use the describe command to know more about a model: ``` llama model describe -m Llama3.2-3B-Instruct ``` -### Describe - ``` +-----------------------------+----------------------------------+ | Model | Llama3.2-3B-Instruct | @@ -234,3 +234,10 @@ llama model prompt-format -m Llama3.2-3B-Instruct You will be shown a Markdown formatted description of the model interface and how prompts / messages are formatted for various scenarios. **NOTE**: Outputs in terminal are color printed to show special tokens. + +### Remove model +You can run `llama model remove` to remove unecessary model: + +``` +llama model remove -m Llama-Guard-3-8B-int8 +``` diff --git a/llama_stack/cli/model/model.py b/llama_stack/cli/model/model.py index 3f8f55773..2f4065b83 100644 --- a/llama_stack/cli/model/model.py +++ b/llama_stack/cli/model/model.py @@ -10,6 +10,7 @@ from llama_stack.cli.model.describe import ModelDescribe from llama_stack.cli.model.download import ModelDownload from llama_stack.cli.model.list import ModelList from llama_stack.cli.model.prompt_format import ModelPromptFormat +from llama_stack.cli.model.remove import ModelRemove from llama_stack.cli.model.verify_download import ModelVerifyDownload from llama_stack.cli.subcommand import Subcommand @@ -35,3 +36,4 @@ class ModelParser(Subcommand): ModelPromptFormat.create(subparsers) ModelDescribe.create(subparsers) ModelVerifyDownload.create(subparsers) + ModelRemove.create(subparsers) diff --git a/llama_stack/cli/model/remove.py b/llama_stack/cli/model/remove.py new file mode 100644 index 000000000..ee8d6299d --- /dev/null +++ b/llama_stack/cli/model/remove.py @@ -0,0 +1,67 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import argparse +import os +import shutil + +from llama_stack.cli.subcommand import Subcommand +from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR +from llama_stack.models.llama.sku_list import resolve_model + + +class ModelRemove(Subcommand): + """Remove the downloaded llama model""" + + def __init__(self, subparsers: argparse._SubParsersAction): + super().__init__() + self.parser = subparsers.add_parser( + "remove", + prog="llama model remove", + description="Remove the downloaded llama model", + formatter_class=argparse.RawTextHelpFormatter, + ) + self._add_arguments() + self.parser.set_defaults(func=self._run_model_remove_cmd) + + def _add_arguments(self): + self.parser.add_argument( + "-m", + "--model", + required=True, + help="Specify the llama downloaded model name, see `llama model list --downloaded`", + ) + self.parser.add_argument( + "-f", + "--force", + action="store_true", + help="Used to forcefully remove the llama model from the storage without further confirmation", + ) + + def _run_model_remove_cmd(self, args: argparse.Namespace) -> None: + from .safety_models import prompt_guard_model_sku + + prompt_guard = prompt_guard_model_sku() + if args.model == prompt_guard.model_id: + model = prompt_guard + else: + model = resolve_model(args.model) + + model_path = os.path.join(DEFAULT_CHECKPOINT_DIR, args.model.replace(":", "-")) + + if model is None or not os.path.isdir(model_path): + print(f"'{args.model}' is not a valid llama model or does not exist.") + return + + if args.force: + shutil.rmtree(model_path) + print(f"{args.model} removed.") + else: + if input(f"Are you sure you want to remove {args.model}? (y/n): ").strip().lower() == "y": + shutil.rmtree(model_path) + print(f"{args.model} removed.") + else: + print("Removal aborted.") From d2701b0d6a57d0a35fc64643400636e29ce802ee Mon Sep 17 00:00:00 2001 From: Reid <61492567+reidliu41@users.noreply.github.com> Date: Sat, 22 Feb 2025 00:06:25 +0800 Subject: [PATCH 15/45] chore: remove configure subcommand (#1202) # What does this PR do? [Provide a short summary of what this PR does and why. Link to relevant issues if applicable.] When tried to use `configure`, and found it `DEPRECATED`, and found pr https://github.com/meta-llama/llama-stack/pull/371 to remove it, not sure why not remove the `configure.py`? ``` $ llama stack configure /tmp/test.yaml usage: llama stack configure [-h] [--output-dir OUTPUT_DIR] config llama stack configure: error: DEPRECATED! llama stack configure has been deprecated. Please use llama stack run instead. Please see example run.yaml in /distributions folder. ``` It would better better to tell when user check it how to use with `--help` first: ``` before: $ llama stack configure --help usage: llama stack configure [-h] [--output-dir OUTPUT_DIR] config Configure a llama stack distribution positional arguments: after: $ llama stack configure --help usage: llama stack configure [-h] [--output-dir OUTPUT_DIR] config Configure a llama stack distribution DEPRECATED! llama stack configure has been deprecated. Please use llama stack run instead. Please see example run.yaml in /distributions folder. ``` [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation) --------- Signed-off-by: reidliu Co-authored-by: reidliu --- llama_stack/cli/stack/configure.py | 46 ------------------------------ llama_stack/cli/stack/stack.py | 2 -- 2 files changed, 48 deletions(-) delete mode 100644 llama_stack/cli/stack/configure.py diff --git a/llama_stack/cli/stack/configure.py b/llama_stack/cli/stack/configure.py deleted file mode 100644 index 2bb3f7313..000000000 --- a/llama_stack/cli/stack/configure.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import argparse - -from llama_stack.cli.subcommand import Subcommand - - -class StackConfigure(Subcommand): - """Llama cli for configuring llama toolchain configs""" - - def __init__(self, subparsers: argparse._SubParsersAction): - super().__init__() - self.parser = subparsers.add_parser( - "configure", - prog="llama stack configure", - description="Configure a llama stack distribution", - formatter_class=argparse.RawTextHelpFormatter, - ) - self._add_arguments() - self.parser.set_defaults(func=self._run_stack_configure_cmd) - - def _add_arguments(self): - self.parser.add_argument( - "config", - type=str, - help="Path to the build config file (e.g. ~/.llama/builds//-build.yaml). For container, this could also be the name of the container image. ", - ) - - self.parser.add_argument( - "--output-dir", - type=str, - help="Path to the output directory to store generated run.yaml config file. If not specified, will use ~/.llama/build//-run.yaml", - ) - - def _run_stack_configure_cmd(self, args: argparse.Namespace) -> None: - self.parser.error( - """ - DEPRECATED! llama stack configure has been deprecated. - Please use llama stack run instead. - Please see example run.yaml in /distributions folder. - """ - ) diff --git a/llama_stack/cli/stack/stack.py b/llama_stack/cli/stack/stack.py index 10e49f8c9..431f7b98e 100644 --- a/llama_stack/cli/stack/stack.py +++ b/llama_stack/cli/stack/stack.py @@ -10,7 +10,6 @@ from importlib.metadata import version from llama_stack.cli.subcommand import Subcommand from .build import StackBuild -from .configure import StackConfigure from .list_apis import StackListApis from .list_providers import StackListProviders from .run import StackRun @@ -37,7 +36,6 @@ class StackParser(Subcommand): # Add sub-commands StackBuild.create(subparsers) - StackConfigure.create(subparsers) StackListApis.create(subparsers) StackListProviders.create(subparsers) StackRun.create(subparsers) From 46da187c0729122993b23aa69b9f697b2f7c525b Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Fri, 21 Feb 2025 10:07:35 -0600 Subject: [PATCH 16/45] fix: remove list of list tests, no longer relevant after #1161 (#1205) # What does this PR do? #1161 updated the embedding signature making the nested list tests irrelevant --- tests/client-sdk/inference/test_embedding.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/client-sdk/inference/test_embedding.py b/tests/client-sdk/inference/test_embedding.py index 602f9c062..3304406a9 100644 --- a/tests/client-sdk/inference/test_embedding.py +++ b/tests/client-sdk/inference/test_embedding.py @@ -12,7 +12,6 @@ # - array of a string # - array of a image (ImageContentItem, either URL or base64 string) # - array of a text (TextContentItem) -# - array of array of texts, images, or both # Types of output: # - list of list of floats # @@ -23,9 +22,6 @@ # - empty string # - empty text # - empty image -# - list of empty texts -# - list of empty images -# - list of empty texts and images # - long # - long string # - long text @@ -36,7 +32,6 @@ # - invalid # - invalid URL # - invalid base64 -# - list of list of strings # # Notes: # - use llama_stack_client fixture From da9f0b786932f7c6995f7c65781da44ec7a25605 Mon Sep 17 00:00:00 2001 From: Rashmi Pawar <168514198+raspawar@users.noreply.github.com> Date: Fri, 21 Feb 2025 21:39:17 +0530 Subject: [PATCH 17/45] test(client-sdk): Update embedding test types to use latest imports (#1203) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? - Updates ImageContentItemImageURL import - fixes `embedding_dimensions` metadata param ## Test Plan - Ran pytest locally, verified embedding tests pass with new types ![Screenshot 2025-02-21 at 6 54 27 PM](https://github.com/user-attachments/assets/f80e3785-04c3-415e-9276-88aa8136bf00) cc: @dglogo @sumitb --- llama_stack/providers/remote/inference/nvidia/models.py | 2 +- llama_stack/templates/nvidia/run.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/llama_stack/providers/remote/inference/nvidia/models.py b/llama_stack/providers/remote/inference/nvidia/models.py index fa9944be1..4305f4c6f 100644 --- a/llama_stack/providers/remote/inference/nvidia/models.py +++ b/llama_stack/providers/remote/inference/nvidia/models.py @@ -52,7 +52,7 @@ _MODEL_ENTRIES = [ provider_model_id="baai/bge-m3", model_type=ModelType.embedding, metadata={ - "embedding_dimensions": 1024, + "embedding_dimension": 1024, "context_length": 8192, }, ), diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml index 891fd112a..4c38ec24e 100644 --- a/llama_stack/templates/nvidia/run.yaml +++ b/llama_stack/templates/nvidia/run.yaml @@ -136,7 +136,7 @@ models: provider_model_id: meta/llama-3.2-90b-vision-instruct model_type: llm - metadata: - embedding_dimensions: 1024 + embedding_dimension: 1024 context_length: 8192 model_id: baai/bge-m3 provider_id: nvidia From 9898589f12d6faa31eac004828e9c8bda364ceb2 Mon Sep 17 00:00:00 2001 From: Reid <61492567+reidliu41@users.noreply.github.com> Date: Sat, 22 Feb 2025 00:10:34 +0800 Subject: [PATCH 18/45] fix: convert back to model descriptor for model in list --downloaded (#1201) # What does this PR do? [Provide a short summary of what this PR does and why. Link to relevant issues if applicable.] Currently , `model` in `--downloaded` just use the directory(already replace `:`), so covert back to descriptor keep the same with ` llama model list`, and remove command also use `descriptor`. ``` before: $ llama model list --downloaded +-------------------------------------+----------+---------------------+ | Model | Size | Modified Time | +-------------------------------------+----------+---------------------+ | Llama3.2-1B-Instruct-int4-qlora-eo8 | 1.53 GB | 2025-02-20 16:32:49 | +-------------------------------------+----------+---------------------+ after: $ llama model list --downloaded +-------------------------------------+----------+---------------------+ | Model | Size | Modified Time | +-------------------------------------+----------+---------------------+ | Llama3.2-1B-Instruct:int4-qlora-eo8 | 1.53 GB | 2025-02-20 16:32:49 | +-------------------------------------+----------+---------------------+ ``` [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation) Signed-off-by: reidliu Co-authored-by: reidliu --- llama_stack/cli/model/list.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/llama_stack/cli/model/list.py b/llama_stack/cli/model/list.py index 2f62cb9ce..622a6b4e7 100644 --- a/llama_stack/cli/model/list.py +++ b/llama_stack/cli/model/list.py @@ -19,6 +19,13 @@ def _get_model_size(model_dir): return sum(f.stat().st_size for f in Path(model_dir).rglob("*") if f.is_file()) +def _convert_to_model_descriptor(model): + for m in all_registered_models(): + if model == m.descriptor().replace(":", "-"): + return str(m.descriptor()) + return str(model) + + def _run_model_list_downloaded_cmd() -> None: headers = ["Model", "Size", "Modified Time"] @@ -30,7 +37,7 @@ def _run_model_list_downloaded_cmd() -> None: modified_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(os.path.getmtime(abs_path))) rows.append( [ - model, + _convert_to_model_descriptor(model), model_size, modified_time, ] From 6634864b196da80c98a324a07cc35e288022107c Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Fri, 21 Feb 2025 11:29:32 -0500 Subject: [PATCH 19/45] docs: Add missing uv command and clarify website rebuild (#1199) # What does this PR do? This fixes the following error: ``` $ make html /bin/sh: line 1: sphinx-build: command not found make: *** [Makefile:20: html] Error 127 ``` Also clarifies that this command only rebuilds the website without watching/refreshes. ## Test Plan New command works. --------- Signed-off-by: Yuan Tang --- CONTRIBUTING.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c5952c8d2..1e4a88f13 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -134,9 +134,11 @@ If you are making changes to the documentation at [https://llama-stack.readthedo $ cd llama-stack/docs $ uv sync --extra docs +# This rebuilds the documentation pages. +$ uv run make html + # This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation. -$ make html -$ uv run sphinx-autobuild source build/html +$ uv run sphinx-autobuild source build/html --write-all ``` ### Update API Documentation From 840fae22593047ab3d805b97e61bf02ba6b4339a Mon Sep 17 00:00:00 2001 From: Jamie Land <38305141+jland-redhat@users.noreply.github.com> Date: Fri, 21 Feb 2025 11:32:56 -0500 Subject: [PATCH 20/45] fix: Updating images so that they are able to run without root access (#1208) # What does this PR do? Addresses issues where the container is unable to run as root. Gives write access to required folders. [//]: # (If resolving an issue, uncomment and update the line below) (Closes #[1207]) ## Test Plan I built locally and ran `llama stack build --template remote-vllm --image-type container` and validated I could see my changes in the output: ``` #11 1.186 Installed 11 packages in 61ms #11 1.186 + llama-models==0.1.3 #11 1.186 + llama-stack==0.1.3 #11 1.186 + llama-stack-client==0.1.3 #11 1.186 + markdown-it-py==3.0.0 #11 1.186 + mdurl==0.1.2 #11 1.186 + prompt-toolkit==3.0.50 #11 1.186 + pyaml==25.1.0 #11 1.186 + pygments==2.19.1 #11 1.186 + rich==13.9.4 #11 1.186 + tiktoken==0.9.0 #11 1.186 + wcwidth==0.2.13 #11 DONE 1.6s #12 [ 9/10] RUN mkdir -p /.llama /.cache #12 DONE 0.3s #13 [10/10] RUN chmod -R g+rw /app /.llama /.cache #13 DONE 0.3s #14 exporting to image #14 exporting layers #14 exporting layers 3.7s done #14 writing image sha256:11cc8bd954db6d036037bcaf471b173ddd5261ac4b1e72074cccf85d18aefb96 done #14 naming to docker.io/library/distribution-remote-vllm:0.1.3 done #14 DONE 3.7s + set +x Success! ``` This is what the resulting image looks like: ![image](https://github.com/user-attachments/assets/070b9c05-b40f-4e7e-aa24-fef260c395e3) Also tagged the image as `0.1.3-test` and [pushed to quay](https://quay.io/repository/jland/distribution-remote-vllm?tab=tags) (note there are a bunch of critical vulnerabilities we may want to look into) And for good measure I deployed the resulting image on my Openshift environment using the default Security Context and validated that there were no issue with it coming up. My validation was all done with the `vllm-remote` distribution, but if I am understanding everything correctly the other distributions are just different run.yaml configs. [//]: # (## Documentation) Please let me know if there is anything else I need to do. Co-authored-by: Jamie Land --- llama_stack/distribution/build_container.sh | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/llama_stack/distribution/build_container.sh b/llama_stack/distribution/build_container.sh index 4101cec44..7c6d758c0 100755 --- a/llama_stack/distribution/build_container.sh +++ b/llama_stack/distribution/build_container.sh @@ -177,6 +177,15 @@ ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"] EOF fi +# Add other require item commands genearic to all containers +add_to_container << EOF + +# Allows running as non-root user +RUN mkdir -p /.llama /.cache + +RUN chmod -R g+rw /app /.llama /.cache +EOF + printf "Containerfile created successfully in $TEMP_DIR/Containerfile\n\n" cat $TEMP_DIR/Containerfile printf "\n" From 11697f85c51d7cda3fb613db3a553a1c549281e8 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 21 Feb 2025 10:35:56 -0800 Subject: [PATCH 21/45] fix: pull ollama embedding model if necessary (#1209) Embedding models are tiny and can be pulled on-demand. Let's do that so the user doesn't have to do "yet another thing" to get themselves set up. Thanks @hardikjshah for the suggestion. Also fixed a build dependency miss (TODO: distro_codegen needs to actually check that the build template contains all providers mentioned for the run.yaml file) ## Test Plan First run `ollama rm all-minilm:latest`. Run `llama stack build --template ollama && llama stack run ollama --env INFERENCE_MODEL=llama3.2:3b-instruct-fp16`. See that it outputs a "Pulling embedding model `all-minilm:latest`" output and the stack starts up correctly. Verify that `ollama list` shows the model is correctly downloaded. --- distributions/dependencies.json | 1 + docs/source/distributions/self_hosted_distro/ollama.md | 2 +- llama_stack/providers/remote/inference/ollama/ollama.py | 2 ++ llama_stack/templates/ollama/build.yaml | 1 + llama_stack/templates/ollama/ollama.py | 2 +- 5 files changed, 6 insertions(+), 2 deletions(-) diff --git a/distributions/dependencies.json b/distributions/dependencies.json index 345a29f33..df63c0773 100644 --- a/distributions/dependencies.json +++ b/distributions/dependencies.json @@ -356,6 +356,7 @@ "scikit-learn", "scipy", "sentencepiece", + "sqlite-vec", "tqdm", "transformers", "uvicorn", diff --git a/docs/source/distributions/self_hosted_distro/ollama.md b/docs/source/distributions/self_hosted_distro/ollama.md index 2fa796e81..b800b4a43 100644 --- a/docs/source/distributions/self_hosted_distro/ollama.md +++ b/docs/source/distributions/self_hosted_distro/ollama.md @@ -23,7 +23,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | telemetry | `inline::meta-reference` | | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` | -| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | +| vector_io | `inline::faiss`, `inline::sqlite_vec`, `remote::chromadb`, `remote::pgvector` | You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since Ollama supports GPU acceleration. diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 62c8381a8..f61ac9898 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -281,6 +281,8 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): async def register_model(self, model: Model) -> Model: if model.model_type == ModelType.embedding: + log.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...") + await self.client.pull(model.provider_resource_id) response = await self.client.list() else: response = await self.client.ps() diff --git a/llama_stack/templates/ollama/build.yaml b/llama_stack/templates/ollama/build.yaml index 0fee6808c..48960c5ba 100644 --- a/llama_stack/templates/ollama/build.yaml +++ b/llama_stack/templates/ollama/build.yaml @@ -6,6 +6,7 @@ distribution_spec: - remote::ollama vector_io: - inline::faiss + - inline::sqlite_vec - remote::chromadb - remote::pgvector safety: diff --git a/llama_stack/templates/ollama/ollama.py b/llama_stack/templates/ollama/ollama.py index 31119e040..2b135c008 100644 --- a/llama_stack/templates/ollama/ollama.py +++ b/llama_stack/templates/ollama/ollama.py @@ -25,7 +25,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def get_distribution_template() -> DistributionTemplate: providers = { "inference": ["remote::ollama"], - "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], + "vector_io": ["inline::faiss", "inline::sqlite_vec", "remote::chromadb", "remote::pgvector"], "safety": ["inline::llama-guard"], "agents": ["inline::meta-reference"], "telemetry": ["inline::meta-reference"], From 992f865b2e416be896cc298ebe0ed710312b663e Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 21 Feb 2025 11:33:41 -0800 Subject: [PATCH 22/45] chore: move embedding deps to RAG tool where they are needed (#1210) `EMBEDDING_DEPS` were wrongly associated with `vector_io` providers. They are needed by https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/utils/memory/vector_store.py#L142 and related code and is used by the RAG tool and as such should only be needed by the `inline::rag-runtime` provider. --- distributions/dependencies.json | 29 +++--------- .../self_hosted_distro/cerebras.md | 2 +- .../distributions/self_hosted_distro/dell.md | 2 +- .../self_hosted_distro/fireworks.md | 2 +- .../self_hosted_distro/ollama.md | 2 +- .../self_hosted_distro/remote-vllm.md | 2 +- .../distributions/self_hosted_distro/tgi.md | 2 +- .../self_hosted_distro/together.md | 2 +- llama_stack/cli/stack/run.py | 7 ++- .../sentence_transformers.py | 1 - llama_stack/providers/registry/inference.py | 5 +- .../providers/registry/tool_runtime.py | 13 ++++- llama_stack/providers/registry/vector_io.py | 47 +++++++------------ .../providers/tests/vector_io/fixtures.py | 2 +- llama_stack/templates/cerebras/build.yaml | 1 + llama_stack/templates/cerebras/cerebras.py | 2 +- llama_stack/templates/dell/build.yaml | 1 + llama_stack/templates/dell/dell.py | 2 +- llama_stack/templates/fireworks/build.yaml | 1 + llama_stack/templates/fireworks/fireworks.py | 2 +- .../templates/hf-serverless/build.yaml | 1 + .../templates/hf-serverless/hf_serverless.py | 2 +- llama_stack/templates/ollama/build.yaml | 3 +- llama_stack/templates/ollama/ollama.py | 33 ++++--------- .../templates/ollama/run-with-safety.yaml | 19 ++------ llama_stack/templates/ollama/run.yaml | 20 +------- llama_stack/templates/remote-vllm/build.yaml | 1 + llama_stack/templates/remote-vllm/vllm.py | 2 +- llama_stack/templates/tgi/build.yaml | 1 + llama_stack/templates/tgi/tgi.py | 2 +- llama_stack/templates/together/build.yaml | 1 + llama_stack/templates/together/together.py | 2 +- llama_stack/templates/vllm-gpu/build.yaml | 1 + llama_stack/templates/vllm-gpu/vllm.py | 2 +- 34 files changed, 85 insertions(+), 132 deletions(-) diff --git a/distributions/dependencies.json b/distributions/dependencies.json index df63c0773..9e468f08d 100644 --- a/distributions/dependencies.json +++ b/distributions/dependencies.json @@ -30,9 +30,7 @@ "sentencepiece", "tqdm", "transformers", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" + "uvicorn" ], "cerebras": [ "aiosqlite", @@ -170,9 +168,7 @@ "sentencepiece", "tqdm", "transformers", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" + "uvicorn" ], "hf-serverless": [ "aiohttp", @@ -247,9 +243,7 @@ "tqdm", "transformers", "uvicorn", - "zmq", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" + "zmq" ], "meta-reference-quantized-gpu": [ "accelerate", @@ -290,9 +284,7 @@ "tqdm", "transformers", "uvicorn", - "zmq", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" + "zmq" ], "nvidia": [ "aiosqlite", @@ -323,9 +315,7 @@ "sentencepiece", "tqdm", "transformers", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" + "uvicorn" ], "ollama": [ "aiohttp", @@ -335,7 +325,6 @@ "chardet", "chromadb-client", "datasets", - "faiss-cpu", "fastapi", "fire", "httpx", @@ -359,9 +348,7 @@ "sqlite-vec", "tqdm", "transformers", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" + "uvicorn" ], "remote-vllm": [ "aiosqlite", @@ -424,9 +411,7 @@ "sentencepiece", "tqdm", "transformers", - "uvicorn", - "sentence-transformers --no-deps", - "torch torchvision --index-url https://download.pytorch.org/whl/cpu" + "uvicorn" ], "tgi": [ "aiohttp", diff --git a/docs/source/distributions/self_hosted_distro/cerebras.md b/docs/source/distributions/self_hosted_distro/cerebras.md index a0c9eb263..6e2af14fd 100644 --- a/docs/source/distributions/self_hosted_distro/cerebras.md +++ b/docs/source/distributions/self_hosted_distro/cerebras.md @@ -8,7 +8,7 @@ The `llamastack/distribution-cerebras` distribution consists of the following pr | agents | `inline::meta-reference` | | datasetio | `remote::huggingface`, `inline::localfs` | | eval | `inline::meta-reference` | -| inference | `remote::cerebras` | +| inference | `remote::cerebras`, `inline::sentence-transformers` | | safety | `inline::llama-guard` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | telemetry | `inline::meta-reference` | diff --git a/docs/source/distributions/self_hosted_distro/dell.md b/docs/source/distributions/self_hosted_distro/dell.md index aef3ecf58..f49b332a9 100644 --- a/docs/source/distributions/self_hosted_distro/dell.md +++ b/docs/source/distributions/self_hosted_distro/dell.md @@ -19,7 +19,7 @@ The `llamastack/distribution-dell` distribution consists of the following provid | agents | `inline::meta-reference` | | datasetio | `remote::huggingface`, `inline::localfs` | | eval | `inline::meta-reference` | -| inference | `remote::tgi` | +| inference | `remote::tgi`, `inline::sentence-transformers` | | safety | `inline::llama-guard` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | telemetry | `inline::meta-reference` | diff --git a/docs/source/distributions/self_hosted_distro/fireworks.md b/docs/source/distributions/self_hosted_distro/fireworks.md index 7951e148e..f69e6d963 100644 --- a/docs/source/distributions/self_hosted_distro/fireworks.md +++ b/docs/source/distributions/self_hosted_distro/fireworks.md @@ -18,7 +18,7 @@ The `llamastack/distribution-fireworks` distribution consists of the following p | agents | `inline::meta-reference` | | datasetio | `remote::huggingface`, `inline::localfs` | | eval | `inline::meta-reference` | -| inference | `remote::fireworks` | +| inference | `remote::fireworks`, `inline::sentence-transformers` | | safety | `inline::llama-guard` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | telemetry | `inline::meta-reference` | diff --git a/docs/source/distributions/self_hosted_distro/ollama.md b/docs/source/distributions/self_hosted_distro/ollama.md index b800b4a43..a487109c8 100644 --- a/docs/source/distributions/self_hosted_distro/ollama.md +++ b/docs/source/distributions/self_hosted_distro/ollama.md @@ -23,7 +23,7 @@ The `llamastack/distribution-ollama` distribution consists of the following prov | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | telemetry | `inline::meta-reference` | | tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime` | -| vector_io | `inline::faiss`, `inline::sqlite_vec`, `remote::chromadb`, `remote::pgvector` | +| vector_io | `inline::sqlite-vec`, `remote::chromadb`, `remote::pgvector` | You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since Ollama supports GPU acceleration. diff --git a/docs/source/distributions/self_hosted_distro/remote-vllm.md b/docs/source/distributions/self_hosted_distro/remote-vllm.md index 6c3bbd1d0..01f38807b 100644 --- a/docs/source/distributions/self_hosted_distro/remote-vllm.md +++ b/docs/source/distributions/self_hosted_distro/remote-vllm.md @@ -17,7 +17,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following | agents | `inline::meta-reference` | | datasetio | `remote::huggingface`, `inline::localfs` | | eval | `inline::meta-reference` | -| inference | `remote::vllm` | +| inference | `remote::vllm`, `inline::sentence-transformers` | | safety | `inline::llama-guard` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | telemetry | `inline::meta-reference` | diff --git a/docs/source/distributions/self_hosted_distro/tgi.md b/docs/source/distributions/self_hosted_distro/tgi.md index f4eecf2cd..80baf9c81 100644 --- a/docs/source/distributions/self_hosted_distro/tgi.md +++ b/docs/source/distributions/self_hosted_distro/tgi.md @@ -19,7 +19,7 @@ The `llamastack/distribution-tgi` distribution consists of the following provide | agents | `inline::meta-reference` | | datasetio | `remote::huggingface`, `inline::localfs` | | eval | `inline::meta-reference` | -| inference | `remote::tgi` | +| inference | `remote::tgi`, `inline::sentence-transformers` | | safety | `inline::llama-guard` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | telemetry | `inline::meta-reference` | diff --git a/docs/source/distributions/self_hosted_distro/together.md b/docs/source/distributions/self_hosted_distro/together.md index 936ae58f5..7af0dcf4d 100644 --- a/docs/source/distributions/self_hosted_distro/together.md +++ b/docs/source/distributions/self_hosted_distro/together.md @@ -18,7 +18,7 @@ The `llamastack/distribution-together` distribution consists of the following pr | agents | `inline::meta-reference` | | datasetio | `remote::huggingface`, `inline::localfs` | | eval | `inline::meta-reference` | -| inference | `remote::together` | +| inference | `remote::together`, `inline::sentence-transformers` | | safety | `inline::llama-guard` | | scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | | telemetry | `inline::meta-reference` | diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index 73536491b..0c9c74518 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -178,6 +178,12 @@ class StackRun(Subcommand): # else must be venv since that is the only valid option left. current_venv = os.environ.get("VIRTUAL_ENV") venv = args.image_name or current_venv + if not venv: + cprint( + "No current virtual environment detected, please specify a virtual environment name with --image-name", + color="red", + ) + return script = importlib.resources.files("llama_stack") / "distribution/start_venv.sh" run_args = [ script, @@ -206,5 +212,4 @@ class StackRun(Subcommand): if args.tls_keyfile and args.tls_certfile: run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile]) - run_with_pty(run_args) diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index 6a83836e6..bfb09af53 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -44,7 +44,6 @@ class SentenceTransformersInferenceImpl( pass async def register_model(self, model: Model) -> None: - _ = self._load_sentence_transformer_model(model.provider_resource_id) return model async def unregister_model(self, model_id: str) -> None: diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 346a2bd73..b0402f6a5 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -61,7 +61,10 @@ def available_providers() -> List[ProviderSpec]: InlineProviderSpec( api=Api.inference, provider_type="inline::sentence-transformers", - pip_packages=["sentence-transformers"], + pip_packages=[ + "torch torchvision --index-url https://download.pytorch.org/whl/cpu", + "sentence-transformers --no-deps", + ], module="llama_stack.providers.inline.inference.sentence_transformers", config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig", ), diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 33d880f30..95ea2dcf9 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -20,7 +20,18 @@ def available_providers() -> List[ProviderSpec]: InlineProviderSpec( api=Api.tool_runtime, provider_type="inline::rag-runtime", - pip_packages=[], + pip_packages=[ + "blobfile", + "chardet", + "pypdf", + "tqdm", + "numpy", + "scikit-learn", + "scipy", + "nltk", + "sentencepiece", + "transformers", + ], module="llama_stack.providers.inline.tool_runtime.rag", config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig", api_dependencies=[Api.vector_io, Api.inference], diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index 88a65397a..ff4f9caf5 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -14,33 +14,13 @@ from llama_stack.providers.datatypes import ( remote_provider_spec, ) -EMBEDDING_DEPS = [ - "blobfile", - "chardet", - "pypdf", - "tqdm", - "numpy", - "scikit-learn", - "scipy", - "nltk", - "sentencepiece", - "transformers", - # this happens to work because special dependencies are always installed last - # so if there was a regular torch installed first, this would be ignored - # we need a better way to do this to identify potential conflicts, etc. - # for now, this lets us significantly reduce the size of the container which - # does not have any "local" inference code (and hence does not need GPU-enabled torch) - "torch torchvision --index-url https://download.pytorch.org/whl/cpu", - "sentence-transformers --no-deps", -] - def available_providers() -> List[ProviderSpec]: return [ InlineProviderSpec( api=Api.vector_io, provider_type="inline::meta-reference", - pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], + pip_packages=["faiss-cpu"], module="llama_stack.providers.inline.vector_io.faiss", config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig", deprecation_warning="Please use the `inline::faiss` provider instead.", @@ -49,24 +29,33 @@ def available_providers() -> List[ProviderSpec]: InlineProviderSpec( api=Api.vector_io, provider_type="inline::faiss", - pip_packages=EMBEDDING_DEPS + ["faiss-cpu"], + pip_packages=["faiss-cpu"], module="llama_stack.providers.inline.vector_io.faiss", config_class="llama_stack.providers.inline.vector_io.faiss.FaissVectorIOConfig", api_dependencies=[Api.inference], ), InlineProviderSpec( api=Api.vector_io, - provider_type="inline::sqlite_vec", - pip_packages=EMBEDDING_DEPS + ["sqlite-vec"], + provider_type="inline::sqlite-vec", + pip_packages=["sqlite-vec"], module="llama_stack.providers.inline.vector_io.sqlite_vec", config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig", api_dependencies=[Api.inference], ), + InlineProviderSpec( + api=Api.vector_io, + provider_type="inline::sqlite_vec", + pip_packages=["sqlite-vec"], + module="llama_stack.providers.inline.vector_io.sqlite_vec", + config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig", + deprecation_warning="Please use the `inline::sqlite-vec` provider (notice the hyphen instead of underscore) instead.", + api_dependencies=[Api.inference], + ), remote_provider_spec( Api.vector_io, AdapterSpec( adapter_type="chromadb", - pip_packages=EMBEDDING_DEPS + ["chromadb-client"], + pip_packages=["chromadb-client"], module="llama_stack.providers.remote.vector_io.chroma", config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig", ), @@ -75,7 +64,7 @@ def available_providers() -> List[ProviderSpec]: InlineProviderSpec( api=Api.vector_io, provider_type="inline::chromadb", - pip_packages=EMBEDDING_DEPS + ["chromadb"], + pip_packages=["chromadb"], module="llama_stack.providers.inline.vector_io.chroma", config_class="llama_stack.providers.inline.vector_io.chroma.ChromaVectorIOConfig", api_dependencies=[Api.inference], @@ -84,7 +73,7 @@ def available_providers() -> List[ProviderSpec]: Api.vector_io, AdapterSpec( adapter_type="pgvector", - pip_packages=EMBEDDING_DEPS + ["psycopg2-binary"], + pip_packages=["psycopg2-binary"], module="llama_stack.providers.remote.vector_io.pgvector", config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig", ), @@ -94,7 +83,7 @@ def available_providers() -> List[ProviderSpec]: Api.vector_io, AdapterSpec( adapter_type="weaviate", - pip_packages=EMBEDDING_DEPS + ["weaviate-client"], + pip_packages=["weaviate-client"], module="llama_stack.providers.remote.vector_io.weaviate", config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig", provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData", @@ -115,7 +104,7 @@ def available_providers() -> List[ProviderSpec]: Api.vector_io, AdapterSpec( adapter_type="qdrant", - pip_packages=EMBEDDING_DEPS + ["qdrant-client"], + pip_packages=["qdrant-client"], module="llama_stack.providers.remote.vector_io.qdrant", config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig", ), diff --git a/llama_stack/providers/tests/vector_io/fixtures.py b/llama_stack/providers/tests/vector_io/fixtures.py index 1797d47a5..c29717a27 100644 --- a/llama_stack/providers/tests/vector_io/fixtures.py +++ b/llama_stack/providers/tests/vector_io/fixtures.py @@ -61,7 +61,7 @@ def vector_io_sqlite_vec() -> ProviderFixture: providers=[ Provider( provider_id="sqlite_vec", - provider_type="inline::sqlite_vec", + provider_type="inline::sqlite-vec", config=SQLiteVectorIOConfig( kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(), ).model_dump(), diff --git a/llama_stack/templates/cerebras/build.yaml b/llama_stack/templates/cerebras/build.yaml index 9d5ab1a52..ef6c43212 100644 --- a/llama_stack/templates/cerebras/build.yaml +++ b/llama_stack/templates/cerebras/build.yaml @@ -4,6 +4,7 @@ distribution_spec: providers: inference: - remote::cerebras + - inline::sentence-transformers safety: - inline::llama-guard vector_io: diff --git a/llama_stack/templates/cerebras/cerebras.py b/llama_stack/templates/cerebras/cerebras.py index c467579ac..544a50c03 100644 --- a/llama_stack/templates/cerebras/cerebras.py +++ b/llama_stack/templates/cerebras/cerebras.py @@ -20,7 +20,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def get_distribution_template() -> DistributionTemplate: providers = { - "inference": ["remote::cerebras"], + "inference": ["remote::cerebras", "inline::sentence-transformers"], "safety": ["inline::llama-guard"], "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], "agents": ["inline::meta-reference"], diff --git a/llama_stack/templates/dell/build.yaml b/llama_stack/templates/dell/build.yaml index e2edb9386..05b98d56f 100644 --- a/llama_stack/templates/dell/build.yaml +++ b/llama_stack/templates/dell/build.yaml @@ -5,6 +5,7 @@ distribution_spec: providers: inference: - remote::tgi + - inline::sentence-transformers vector_io: - inline::faiss - remote::chromadb diff --git a/llama_stack/templates/dell/dell.py b/llama_stack/templates/dell/dell.py index 116fbd285..8348beafd 100644 --- a/llama_stack/templates/dell/dell.py +++ b/llama_stack/templates/dell/dell.py @@ -20,7 +20,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def get_distribution_template() -> DistributionTemplate: providers = { - "inference": ["remote::tgi"], + "inference": ["remote::tgi", "inline::sentence-transformers"], "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], "safety": ["inline::llama-guard"], "agents": ["inline::meta-reference"], diff --git a/llama_stack/templates/fireworks/build.yaml b/llama_stack/templates/fireworks/build.yaml index cdd60ec2a..a9c472c53 100644 --- a/llama_stack/templates/fireworks/build.yaml +++ b/llama_stack/templates/fireworks/build.yaml @@ -4,6 +4,7 @@ distribution_spec: providers: inference: - remote::fireworks + - inline::sentence-transformers vector_io: - inline::faiss - remote::chromadb diff --git a/llama_stack/templates/fireworks/fireworks.py b/llama_stack/templates/fireworks/fireworks.py index 06b851551..4457296b0 100644 --- a/llama_stack/templates/fireworks/fireworks.py +++ b/llama_stack/templates/fireworks/fireworks.py @@ -25,7 +25,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def get_distribution_template() -> DistributionTemplate: providers = { - "inference": ["remote::fireworks"], + "inference": ["remote::fireworks", "inline::sentence-transformers"], "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], "safety": ["inline::llama-guard"], "agents": ["inline::meta-reference"], diff --git a/llama_stack/templates/hf-serverless/build.yaml b/llama_stack/templates/hf-serverless/build.yaml index f9303cfab..c0cc1e2c2 100644 --- a/llama_stack/templates/hf-serverless/build.yaml +++ b/llama_stack/templates/hf-serverless/build.yaml @@ -4,6 +4,7 @@ distribution_spec: providers: inference: - remote::hf::serverless + - inline::sentence-transformers vector_io: - inline::faiss - remote::chromadb diff --git a/llama_stack/templates/hf-serverless/hf_serverless.py b/llama_stack/templates/hf-serverless/hf_serverless.py index 46efb6f0b..af04e39d4 100644 --- a/llama_stack/templates/hf-serverless/hf_serverless.py +++ b/llama_stack/templates/hf-serverless/hf_serverless.py @@ -21,7 +21,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def get_distribution_template() -> DistributionTemplate: providers = { - "inference": ["remote::hf::serverless"], + "inference": ["remote::hf::serverless", "inline::sentence-transformers"], "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], "safety": ["inline::llama-guard"], "agents": ["inline::meta-reference"], diff --git a/llama_stack/templates/ollama/build.yaml b/llama_stack/templates/ollama/build.yaml index 48960c5ba..52a50b38a 100644 --- a/llama_stack/templates/ollama/build.yaml +++ b/llama_stack/templates/ollama/build.yaml @@ -5,8 +5,7 @@ distribution_spec: inference: - remote::ollama vector_io: - - inline::faiss - - inline::sqlite_vec + - inline::sqlite-vec - remote::chromadb - remote::pgvector safety: diff --git a/llama_stack/templates/ollama/ollama.py b/llama_stack/templates/ollama/ollama.py index 2b135c008..4f644c270 100644 --- a/llama_stack/templates/ollama/ollama.py +++ b/llama_stack/templates/ollama/ollama.py @@ -13,10 +13,6 @@ from llama_stack.distribution.datatypes import ( ShieldInput, ToolGroupInput, ) -from llama_stack.providers.inline.inference.sentence_transformers import ( - SentenceTransformersInferenceConfig, -) -from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -25,7 +21,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def get_distribution_template() -> DistributionTemplate: providers = { "inference": ["remote::ollama"], - "vector_io": ["inline::faiss", "inline::sqlite_vec", "remote::chromadb", "remote::pgvector"], + "vector_io": ["inline::sqlite-vec", "remote::chromadb", "remote::pgvector"], "safety": ["inline::llama-guard"], "agents": ["inline::meta-reference"], "telemetry": ["inline::meta-reference"], @@ -45,19 +41,9 @@ def get_distribution_template() -> DistributionTemplate: provider_type="remote::ollama", config=OllamaImplConfig.sample_run_config(), ) - embedding_provider = Provider( - provider_id="sentence-transformers", - provider_type="inline::sentence-transformers", - config=SentenceTransformersInferenceConfig.sample_run_config(), - ) - vector_io_provider_faiss = Provider( - provider_id="faiss", - provider_type="inline::faiss", - config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"), - ) vector_io_provider_sqlite = Provider( - provider_id="sqlite_vec", - provider_type="inline::sqlite_vec", + provider_id="sqlite-vec", + provider_type="inline::sqlite-vec", config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"), ) @@ -104,19 +90,16 @@ def get_distribution_template() -> DistributionTemplate: run_configs={ "run.yaml": RunConfigSettings( provider_overrides={ - "inference": [inference_provider, embedding_provider], - "vector_io": [vector_io_provider_faiss, vector_io_provider_sqlite], + "inference": [inference_provider], + "vector_io": [vector_io_provider_sqlite], }, - default_models=[inference_model, embedding_model], + default_models=[inference_model], default_tool_groups=default_tool_groups, ), "run-with-safety.yaml": RunConfigSettings( provider_overrides={ - "inference": [ - inference_provider, - embedding_provider, - ], - "vector_io": [vector_io_provider_faiss, vector_io_provider_faiss], + "inference": [inference_provider], + "vector_io": [vector_io_provider_sqlite], "safety": [ Provider( provider_id="llama-guard", diff --git a/llama_stack/templates/ollama/run-with-safety.yaml b/llama_stack/templates/ollama/run-with-safety.yaml index 7cf527c04..063840a50 100644 --- a/llama_stack/templates/ollama/run-with-safety.yaml +++ b/llama_stack/templates/ollama/run-with-safety.yaml @@ -16,24 +16,11 @@ providers: provider_type: remote::ollama config: url: ${env.OLLAMA_URL:http://localhost:11434} - - provider_id: sentence-transformers - provider_type: inline::sentence-transformers - config: {} vector_io: - - provider_id: faiss - provider_type: inline::faiss + - provider_id: sqlite-vec + provider_type: inline::sqlite-vec config: - kvstore: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db - - provider_id: faiss - provider_type: inline::faiss - config: - kvstore: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml index 1f45fc228..d64e07347 100644 --- a/llama_stack/templates/ollama/run.yaml +++ b/llama_stack/templates/ollama/run.yaml @@ -16,19 +16,9 @@ providers: provider_type: remote::ollama config: url: ${env.OLLAMA_URL:http://localhost:11434} - - provider_id: sentence-transformers - provider_type: inline::sentence-transformers - config: {} vector_io: - - provider_id: faiss - provider_type: inline::faiss - config: - kvstore: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db - - provider_id: sqlite_vec - provider_type: inline::sqlite_vec + - provider_id: sqlite-vec + provider_type: inline::sqlite-vec config: db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db safety: @@ -97,12 +87,6 @@ models: model_id: ${env.INFERENCE_MODEL} provider_id: ollama model_type: llm -- metadata: - embedding_dimension: 384 - model_id: all-MiniLM-L6-v2 - provider_id: ollama - provider_model_id: all-minilm:latest - model_type: embedding shields: [] vector_dbs: [] datasets: [] diff --git a/llama_stack/templates/remote-vllm/build.yaml b/llama_stack/templates/remote-vllm/build.yaml index 74d9f32d9..ccb328c1c 100644 --- a/llama_stack/templates/remote-vllm/build.yaml +++ b/llama_stack/templates/remote-vllm/build.yaml @@ -4,6 +4,7 @@ distribution_spec: providers: inference: - remote::vllm + - inline::sentence-transformers vector_io: - inline::faiss - remote::chromadb diff --git a/llama_stack/templates/remote-vllm/vllm.py b/llama_stack/templates/remote-vllm/vllm.py index 40a2d541d..10d291456 100644 --- a/llama_stack/templates/remote-vllm/vllm.py +++ b/llama_stack/templates/remote-vllm/vllm.py @@ -23,7 +23,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def get_distribution_template() -> DistributionTemplate: providers = { - "inference": ["remote::vllm"], + "inference": ["remote::vllm", "inline::sentence-transformers"], "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], "safety": ["inline::llama-guard"], "agents": ["inline::meta-reference"], diff --git a/llama_stack/templates/tgi/build.yaml b/llama_stack/templates/tgi/build.yaml index 8bc628158..9fe79647c 100644 --- a/llama_stack/templates/tgi/build.yaml +++ b/llama_stack/templates/tgi/build.yaml @@ -4,6 +4,7 @@ distribution_spec: providers: inference: - remote::tgi + - inline::sentence-transformers vector_io: - inline::faiss - remote::chromadb diff --git a/llama_stack/templates/tgi/tgi.py b/llama_stack/templates/tgi/tgi.py index 71718a93d..9b80414f9 100644 --- a/llama_stack/templates/tgi/tgi.py +++ b/llama_stack/templates/tgi/tgi.py @@ -23,7 +23,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def get_distribution_template() -> DistributionTemplate: providers = { - "inference": ["remote::tgi"], + "inference": ["remote::tgi", "inline::sentence-transformers"], "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], "safety": ["inline::llama-guard"], "agents": ["inline::meta-reference"], diff --git a/llama_stack/templates/together/build.yaml b/llama_stack/templates/together/build.yaml index 90ee5bcee..a8a6de28d 100644 --- a/llama_stack/templates/together/build.yaml +++ b/llama_stack/templates/together/build.yaml @@ -4,6 +4,7 @@ distribution_spec: providers: inference: - remote::together + - inline::sentence-transformers vector_io: - inline::faiss - remote::chromadb diff --git a/llama_stack/templates/together/together.py b/llama_stack/templates/together/together.py index d275b7238..8d0e2353c 100644 --- a/llama_stack/templates/together/together.py +++ b/llama_stack/templates/together/together.py @@ -25,7 +25,7 @@ from llama_stack.templates.template import DistributionTemplate, RunConfigSettin def get_distribution_template() -> DistributionTemplate: providers = { - "inference": ["remote::together"], + "inference": ["remote::together", "inline::sentence-transformers"], "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], "safety": ["inline::llama-guard"], "agents": ["inline::meta-reference"], diff --git a/llama_stack/templates/vllm-gpu/build.yaml b/llama_stack/templates/vllm-gpu/build.yaml index d24046613..8eb44dc1b 100644 --- a/llama_stack/templates/vllm-gpu/build.yaml +++ b/llama_stack/templates/vllm-gpu/build.yaml @@ -4,6 +4,7 @@ distribution_spec: providers: inference: - inline::vllm + - inline::sentence-transformers vector_io: - inline::faiss - remote::chromadb diff --git a/llama_stack/templates/vllm-gpu/vllm.py b/llama_stack/templates/vllm-gpu/vllm.py index 31900687b..8cdec589e 100644 --- a/llama_stack/templates/vllm-gpu/vllm.py +++ b/llama_stack/templates/vllm-gpu/vllm.py @@ -20,7 +20,7 @@ from llama_stack.templates.template import ( def get_distribution_template() -> DistributionTemplate: providers = { - "inference": ["inline::vllm"], + "inference": ["inline::vllm", "inline::sentence-transformers"], "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], "safety": ["inline::llama-guard"], "agents": ["inline::meta-reference"], From 0fe071764f8902aa41487e77df5faf292d47ba5f Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Fri, 21 Feb 2025 11:48:27 -0800 Subject: [PATCH 23/45] feat(1/n): api: unify agents for handling server & client tools (#1178) # Problem Our current Agent framework has discrepancies in definition on how we handle server side and client side tools. 1. Server Tools: a single Turn is returned including `ToolExecutionStep` in agenst 2. Client Tools: `create_agent_turn` is called in loop with client agent lib yielding the agent chunk https://github.com/meta-llama/llama-stack-client-python/blob/ad6ffc63df658674f275267b1befc2b7046dbf33/src/llama_stack_client/lib/agents/agent.py#L186-L211 This makes it inconsistent to work with server & client tools. It also complicates the logs to telemetry to get information about agents turn / history for observability. #### Principle The same `turn_id` should be used to represent the steps required to complete a user message including client tools. ## Solution 1. `AgentTurnResponseEventType.turn_awaiting_input` status to indicate that the current turn is not completed, and awaiting tool input 2. `continue_agent_turn` endpoint to update agent turn with client's tool response. # What does this PR do? - Skeleton API as example ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] - Just API update, no functionality change ``` llama stack run + client-sdk test ``` image [//]: # (## Documentation) --- docs/_static/llama-stack-spec.html | 113 ++++++++++- docs/_static/llama-stack-spec.yaml | 82 ++++++++ llama_stack/apis/agents/agents.py | 48 +++++ .../agents/meta_reference/agent_instance.py | 179 ++++++++++++++++-- .../inline/agents/meta_reference/agents.py | 31 +++ .../agents/meta_reference/persistence.py | 14 +- tests/client-sdk/agents/test_agents.py | 8 +- 7 files changed, 454 insertions(+), 21 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index fab7c802e..ce08e041f 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -2315,6 +2315,70 @@ } } }, + "/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume": { + "post": { + "responses": { + "200": { + "description": "A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Turn" + } + }, + "text/event-stream": { + "schema": { + "$ref": "#/components/schemas/AgentTurnResponseStreamChunk" + } + } + } + } + }, + "tags": [ + "Agents" + ], + "description": "Resume an agent turn with executed tool call responses.\nWhen a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.", + "parameters": [ + { + "name": "agent_id", + "in": "path", + "description": "The ID of the agent to resume.", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "session_id", + "in": "path", + "description": "The ID of the session to resume.", + "required": true, + "schema": { + "type": "string" + } + }, + { + "name": "turn_id", + "in": "path", + "description": "The ID of the turn to resume.", + "required": true, + "schema": { + "type": "string" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ResumeAgentTurnRequest" + } + } + }, + "required": true + } + } + }, "/v1/eval/benchmarks/{benchmark_id}/jobs": { "post": { "responses": { @@ -4226,6 +4290,9 @@ }, "tool_config": { "$ref": "#/components/schemas/ToolConfig" + }, + "allow_turn_resume": { + "type": "boolean" } }, "additionalProperties": false, @@ -4612,6 +4679,9 @@ }, { "$ref": "#/components/schemas/AgentTurnResponseTurnCompletePayload" + }, + { + "$ref": "#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload" } ], "discriminator": { @@ -4621,7 +4691,8 @@ "step_progress": "#/components/schemas/AgentTurnResponseStepProgressPayload", "step_complete": "#/components/schemas/AgentTurnResponseStepCompletePayload", "turn_start": "#/components/schemas/AgentTurnResponseTurnStartPayload", - "turn_complete": "#/components/schemas/AgentTurnResponseTurnCompletePayload" + "turn_complete": "#/components/schemas/AgentTurnResponseTurnCompletePayload", + "turn_awaiting_input": "#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload" } } }, @@ -4784,6 +4855,25 @@ "title": "AgentTurnResponseStreamChunk", "description": "streamed agent turn completion response." }, + "AgentTurnResponseTurnAwaitingInputPayload": { + "type": "object", + "properties": { + "event_type": { + "type": "string", + "const": "turn_awaiting_input", + "default": "turn_awaiting_input" + }, + "turn": { + "$ref": "#/components/schemas/Turn" + } + }, + "additionalProperties": false, + "required": [ + "event_type", + "turn" + ], + "title": "AgentTurnResponseTurnAwaitingInputPayload" + }, "AgentTurnResponseTurnCompletePayload": { "type": "object", "properties": { @@ -8046,6 +8136,27 @@ ], "title": "RegisterVectorDbRequest" }, + "ResumeAgentTurnRequest": { + "type": "object", + "properties": { + "tool_responses": { + "type": "array", + "items": { + "$ref": "#/components/schemas/ToolResponseMessage" + }, + "description": "The tool call responses to resume the turn with." + }, + "stream": { + "type": "boolean", + "description": "Whether to stream the response." + } + }, + "additionalProperties": false, + "required": [ + "tool_responses" + ], + "title": "ResumeAgentTurnRequest" + }, "RunEvalRequest": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index fc57bf258..0e4955a5c 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -1401,6 +1401,53 @@ paths: schema: $ref: '#/components/schemas/QueryTracesRequest' required: true + /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume: + post: + responses: + '200': + description: >- + A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk + objects. + content: + application/json: + schema: + $ref: '#/components/schemas/Turn' + text/event-stream: + schema: + $ref: '#/components/schemas/AgentTurnResponseStreamChunk' + tags: + - Agents + description: >- + Resume an agent turn with executed tool call responses. + + When a Turn has the status `awaiting_input` due to pending input from client + side tool calls, this endpoint can be used to submit the outputs from the + tool calls once they are ready. + parameters: + - name: agent_id + in: path + description: The ID of the agent to resume. + required: true + schema: + type: string + - name: session_id + in: path + description: The ID of the session to resume. + required: true + schema: + type: string + - name: turn_id + in: path + description: The ID of the turn to resume. + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/ResumeAgentTurnRequest' + required: true /v1/eval/benchmarks/{benchmark_id}/jobs: post: responses: @@ -2740,6 +2787,8 @@ components: $ref: '#/components/schemas/AgentTool' tool_config: $ref: '#/components/schemas/ToolConfig' + allow_turn_resume: + type: boolean additionalProperties: false required: - messages @@ -2992,6 +3041,7 @@ components: - $ref: '#/components/schemas/AgentTurnResponseStepCompletePayload' - $ref: '#/components/schemas/AgentTurnResponseTurnStartPayload' - $ref: '#/components/schemas/AgentTurnResponseTurnCompletePayload' + - $ref: '#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload' discriminator: propertyName: event_type mapping: @@ -3000,6 +3050,7 @@ components: step_complete: '#/components/schemas/AgentTurnResponseStepCompletePayload' turn_start: '#/components/schemas/AgentTurnResponseTurnStartPayload' turn_complete: '#/components/schemas/AgentTurnResponseTurnCompletePayload' + turn_awaiting_input: '#/components/schemas/AgentTurnResponseTurnAwaitingInputPayload' AgentTurnResponseStepCompletePayload: type: object properties: @@ -3106,6 +3157,21 @@ components: - event title: AgentTurnResponseStreamChunk description: streamed agent turn completion response. + "AgentTurnResponseTurnAwaitingInputPayload": + type: object + properties: + event_type: + type: string + const: turn_awaiting_input + default: turn_awaiting_input + turn: + $ref: '#/components/schemas/Turn' + additionalProperties: false + required: + - event_type + - turn + title: >- + AgentTurnResponseTurnAwaitingInputPayload AgentTurnResponseTurnCompletePayload: type: object properties: @@ -5205,6 +5271,22 @@ components: - vector_db_id - embedding_model title: RegisterVectorDbRequest + ResumeAgentTurnRequest: + type: object + properties: + tool_responses: + type: array + items: + $ref: '#/components/schemas/ToolResponseMessage' + description: >- + The tool call responses to resume the turn with. + stream: + type: boolean + description: Whether to stream the response. + additionalProperties: false + required: + - tool_responses + title: ResumeAgentTurnRequest RunEvalRequest: type: object properties: diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 367648ded..c904fdbef 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -194,6 +194,7 @@ class AgentTurnResponseEventType(Enum): turn_start = "turn_start" turn_complete = "turn_complete" + turn_awaiting_input = "turn_awaiting_input" @json_schema_type @@ -235,6 +236,14 @@ class AgentTurnResponseTurnCompletePayload(BaseModel): turn: Turn +@json_schema_type +class AgentTurnResponseTurnAwaitingInputPayload(BaseModel): + event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input.value] = ( + AgentTurnResponseEventType.turn_awaiting_input.value + ) + turn: Turn + + AgentTurnResponseEventPayload = register_schema( Annotated[ Union[ @@ -243,6 +252,7 @@ AgentTurnResponseEventPayload = register_schema( AgentTurnResponseStepCompletePayload, AgentTurnResponseTurnStartPayload, AgentTurnResponseTurnCompletePayload, + AgentTurnResponseTurnAwaitingInputPayload, ], Field(discriminator="event_type"), ], @@ -286,6 +296,18 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn): stream: Optional[bool] = False tool_config: Optional[ToolConfig] = None + # TODO (xiyan): temporary flag, will remove for 0.1.5 + allow_turn_resume: Optional[bool] = False + + +@json_schema_type +class AgentTurnResumeRequest(BaseModel): + agent_id: str + session_id: str + turn_id: str + tool_responses: List[ToolResponseMessage] + stream: Optional[bool] = False + @json_schema_type class AgentTurnResponseStreamChunk(BaseModel): @@ -333,8 +355,34 @@ class Agents(Protocol): documents: Optional[List[Document]] = None, toolgroups: Optional[List[AgentToolGroup]] = None, tool_config: Optional[ToolConfig] = None, + allow_turn_resume: Optional[bool] = False, ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... + @webmethod( + route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume", + method="POST", + ) + async def resume_agent_turn( + self, + agent_id: str, + session_id: str, + turn_id: str, + tool_responses: List[ToolResponseMessage], + stream: Optional[bool] = False, + ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: + """Resume an agent turn with executed tool call responses. + + When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready. + + :param agent_id: The ID of the agent to resume. + :param session_id: The ID of the session to resume. + :param turn_id: The ID of the turn to resume. + :param tool_responses: The tool call responses to resume the turn with. + :param stream: Whether to stream the response. + :returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects. + """ + ... + @webmethod( route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET", diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 1c21df57f..edd253356 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -30,8 +30,10 @@ from llama_stack.apis.agents import ( AgentTurnResponseStepProgressPayload, AgentTurnResponseStepStartPayload, AgentTurnResponseStreamChunk, + AgentTurnResponseTurnAwaitingInputPayload, AgentTurnResponseTurnCompletePayload, AgentTurnResponseTurnStartPayload, + AgentTurnResumeRequest, Attachment, Document, InferenceStep, @@ -62,7 +64,11 @@ from llama_stack.apis.inference import ( from llama_stack.apis.safety import Safety from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO -from llama_stack.models.llama.datatypes import BuiltinTool, ToolCall, ToolParamDefinition +from llama_stack.models.llama.datatypes import ( + BuiltinTool, + ToolCall, + ToolParamDefinition, +) from llama_stack.providers.utils.kvstore import KVStore from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content from llama_stack.providers.utils.telemetry import tracing @@ -151,6 +157,15 @@ class ChatAgent(ShieldRunnerMixin): async def create_session(self, name: str) -> str: return await self.storage.create_session(name) + async def get_messages_from_turns(self, turns: List[Turn]) -> List[Message]: + messages = [] + if self.agent_config.instructions != "": + messages.append(SystemMessage(content=self.agent_config.instructions)) + + for turn in turns: + messages.extend(self.turn_to_messages(turn)) + return messages + async def create_and_execute_turn(self, request: AgentTurnCreateRequest) -> AsyncGenerator: with tracing.span("create_and_execute_turn") as span: span.set_attribute("session_id", request.session_id) @@ -163,14 +178,7 @@ class ChatAgent(ShieldRunnerMixin): raise ValueError(f"Session {request.session_id} not found") turns = await self.storage.get_session_turns(request.session_id) - - messages = [] - if self.agent_config.instructions != "": - messages.append(SystemMessage(content=self.agent_config.instructions)) - - for i, turn in enumerate(turns): - messages.extend(self.turn_to_messages(turn)) - + messages = await self.get_messages_from_turns(turns) messages.extend(request.messages) turn_id = str(uuid.uuid4()) @@ -222,13 +230,136 @@ class ChatAgent(ShieldRunnerMixin): ) await self.storage.add_turn_to_session(request.session_id, turn) - chunk = AgentTurnResponseStreamChunk( + if output_message.tool_calls and request.allow_turn_resume: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnAwaitingInputPayload( + turn=turn, + ) + ) + ) + else: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + turn=turn, + ) + ) + ) + + yield chunk + + async def resume_turn(self, request: AgentTurnResumeRequest) -> AsyncGenerator: + with tracing.span("resume_turn") as span: + span.set_attribute("agent_id", self.agent_id) + span.set_attribute("session_id", request.session_id) + span.set_attribute("turn_id", request.turn_id) + span.set_attribute("request", request.model_dump_json()) + assert request.stream is True, "Non-streaming not supported" + + session_info = await self.storage.get_session_info(request.session_id) + if session_info is None: + raise ValueError(f"Session {request.session_id} not found") + + turns = await self.storage.get_session_turns(request.session_id) + messages = await self.get_messages_from_turns(turns) + messages.extend(request.tool_responses) + + last_turn_messages = [ + x for x in messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) + ] + + # get the steps from the turn id + steps = [] + if len(turns) > 0: + steps = turns[-1].steps + + # mark tool execution step as complete + # if there's no tool execution in progress step (due to storage, or tool call parsing on client), + # we'll create a new tool execution step with current time + in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step( + request.session_id, request.turn_id + ) + now = datetime.now() + tool_execution_step = ToolExecutionStep( + step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), + turn_id=request.turn_id, + tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), + tool_responses=[ + ToolResponse( + call_id=x.call_id, + tool_name=x.tool_name, + content=x.content, + ) + for x in request.tool_responses + ], + completed_at=now, + started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), + ) + steps.append(tool_execution_step) + yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( - payload=AgentTurnResponseTurnCompletePayload( - turn=turn, + payload=AgentTurnResponseStepCompletePayload( + step_type=StepType.tool_execution.value, + step_id=tool_execution_step.step_id, + step_details=tool_execution_step, ) ) ) + + output_message = None + async for chunk in self.run( + session_id=request.session_id, + turn_id=request.turn_id, + input_messages=messages, + sampling_params=self.agent_config.sampling_params, + stream=request.stream, + ): + if isinstance(chunk, CompletionMessage): + output_message = chunk + continue + + assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" + event = chunk.event + if event.payload.event_type == AgentTurnResponseEventType.step_complete.value: + steps.append(event.payload.step_details) + + yield chunk + + assert output_message is not None + + last_turn_start_time = datetime.now() + if len(turns) > 0: + last_turn_start_time = turns[-1].started_at + + turn = Turn( + turn_id=request.turn_id, + session_id=request.session_id, + input_messages=last_turn_messages, + output_message=output_message, + started_at=last_turn_start_time, + completed_at=datetime.now(), + steps=steps, + ) + await self.storage.add_turn_to_session(request.session_id, turn) + + if output_message.tool_calls: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnAwaitingInputPayload( + turn=turn, + ) + ) + ) + else: + chunk = AgentTurnResponseStreamChunk( + event=AgentTurnResponseEvent( + payload=AgentTurnResponseTurnCompletePayload( + turn=turn, + ) + ) + ) + yield chunk async def run( @@ -611,11 +742,7 @@ class ChatAgent(ShieldRunnerMixin): input_messages = input_messages + [message] else: log.info(f"{str(message)}") - tool_call = message.tool_calls[0] - if tool_call.tool_name in client_tools: - yield message - return - + # 1. Start the tool execution step and progress step_id = str(uuid.uuid4()) yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -625,6 +752,7 @@ class ChatAgent(ShieldRunnerMixin): ) ) ) + tool_call = message.tool_calls[0] yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepProgressPayload( @@ -639,6 +767,23 @@ class ChatAgent(ShieldRunnerMixin): ) ) + # If tool is a client tool, yield CompletionMessage and return + if tool_call.tool_name in client_tools: + await self.storage.set_in_progress_tool_call_step( + session_id, + turn_id, + ToolExecutionStep( + step_id=step_id, + turn_id=turn_id, + tool_calls=[tool_call], + tool_responses=[], + started_at=datetime.now(), + ), + ) + yield message + return + + # If tool is a builtin server tool, execute it tool_name = tool_call.tool_name if isinstance(tool_name, BuiltinTool): tool_name = tool_name.value diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index e3c18d112..8a4d91238 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -21,6 +21,7 @@ from llama_stack.apis.agents import ( AgentStepResponse, AgentToolGroup, AgentTurnCreateRequest, + AgentTurnResumeRequest, Document, Session, Turn, @@ -146,6 +147,7 @@ class MetaReferenceAgentsImpl(Agents): documents: Optional[List[Document]] = None, stream: Optional[bool] = False, tool_config: Optional[ToolConfig] = None, + allow_turn_resume: Optional[bool] = False, ) -> AsyncGenerator: request = AgentTurnCreateRequest( agent_id=agent_id, @@ -155,6 +157,7 @@ class MetaReferenceAgentsImpl(Agents): toolgroups=toolgroups, documents=documents, tool_config=tool_config, + allow_turn_resume=allow_turn_resume, ) if stream: return self._create_agent_turn_streaming(request) @@ -169,6 +172,34 @@ class MetaReferenceAgentsImpl(Agents): async for event in agent.create_and_execute_turn(request): yield event + async def resume_agent_turn( + self, + agent_id: str, + session_id: str, + turn_id: str, + tool_responses: List[ToolResponseMessage], + stream: Optional[bool] = False, + ) -> AsyncGenerator: + request = AgentTurnResumeRequest( + agent_id=agent_id, + session_id=session_id, + turn_id=turn_id, + tool_responses=tool_responses, + stream=stream, + ) + if stream: + return self._continue_agent_turn_streaming(request) + else: + raise NotImplementedError("Non-streaming agent turns not yet implemented") + + async def _continue_agent_turn_streaming( + self, + request: AgentTurnResumeRequest, + ) -> AsyncGenerator: + agent = await self.get_agent(request.agent_id) + async for event in agent.resume_turn(request): + yield event + async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn: turn = await self.persistence_store.get(f"session:{agent_id}:{session_id}:{turn_id}") turn = json.loads(turn) diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 4b8ad6d4a..3c3866873 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -12,7 +12,7 @@ from typing import List, Optional from pydantic import BaseModel -from llama_stack.apis.agents import Turn +from llama_stack.apis.agents import ToolExecutionStep, Turn from llama_stack.providers.utils.kvstore import KVStore log = logging.getLogger(__name__) @@ -84,3 +84,15 @@ class AgentPersistence: continue turns.sort(key=lambda x: (x.completed_at or datetime.min)) return turns + + async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep): + await self.kvstore.set( + key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", + value=step.model_dump_json(), + ) + + async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> Optional[ToolExecutionStep]: + value = await self.kvstore.get( + key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}", + ) + return ToolExecutionStep(**json.loads(value)) if value else None diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index e5380d357..781095d2b 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -19,8 +19,12 @@ from llama_stack_client.types.shared.completion_message import CompletionMessage from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig from llama_stack_client.types.tool_def_param import Parameter -from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig -from llama_stack.apis.agents.agents import ToolChoice +from llama_stack.apis.agents.agents import ( + AgentConfig as Server__AgentConfig, +) +from llama_stack.apis.agents.agents import ( + ToolChoice, +) class TestClientTool(ClientTool): From 36162c8c82648843febfbe359d237e362a0b118a Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 21 Feb 2025 12:51:38 -0800 Subject: [PATCH 24/45] fix(ollama): register model with the helper first so it gets normalized --- llama_stack/providers/remote/inference/ollama/ollama.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index f61ac9898..058bbeeee 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -280,6 +280,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): return EmbeddingsResponse(embeddings=embeddings) async def register_model(self, model: Model) -> Model: + model = await self.register_helper.register_model(model) if model.model_type == ModelType.embedding: log.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...") await self.client.pull(model.provider_resource_id) @@ -292,7 +293,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): f"Model '{model.provider_resource_id}' is not available in Ollama. Available models: {', '.join(available_models)}" ) - return await self.register_helper.register_model(model) + return model async def convert_message_to_openai_dict_for_ollama(message: Message) -> List[dict]: From 25fddccfd80670234ab7a32b8cdf381ca3282e74 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Fri, 21 Feb 2025 13:15:31 -0800 Subject: [PATCH 25/45] feat: tool outputs metadata (#1155) Summary: Allows tools to output metadata. This is useful for evaluating tool outputs, e.g. RAG tool will output document IDs, which can be used to score recall. Will need to make a similar change on the client side to support ClientTool outputting metadata. Test Plan: LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py --- docs/_static/llama-stack-spec.html | 78 +++++++++++++++++++ docs/_static/llama-stack-spec.yaml | 32 ++++++++ llama_stack/apis/inference/inference.py | 1 + llama_stack/apis/tools/rag_tool.py | 1 + llama_stack/apis/tools/tools.py | 1 + .../agents/meta_reference/agent_instance.py | 38 ++++----- .../inline/tool_runtime/rag/memory.py | 7 +- tests/client-sdk/agents/test_agents.py | 11 ++- 8 files changed, 141 insertions(+), 28 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index ce08e041f..2a9f4b6f7 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -4521,6 +4521,31 @@ }, "content": { "$ref": "#/components/schemas/InterleavedContent" + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } } }, "additionalProperties": false, @@ -6746,6 +6771,31 @@ }, "error_code": { "type": "integer" + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } } }, "additionalProperties": false, @@ -7595,9 +7645,37 @@ "properties": { "content": { "$ref": "#/components/schemas/InterleavedContent" + }, + "metadata": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } } }, "additionalProperties": false, + "required": [ + "metadata" + ], "title": "RAGQueryResult" }, "QueryChunksRequest": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 0e4955a5c..a2329e47a 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -2945,6 +2945,16 @@ components: - type: string content: $ref: '#/components/schemas/InterleavedContent' + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object additionalProperties: false required: - call_id @@ -4381,6 +4391,16 @@ components: type: string error_code: type: integer + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object additionalProperties: false required: - content @@ -4954,7 +4974,19 @@ components: properties: content: $ref: '#/components/schemas/InterleavedContent' + metadata: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object additionalProperties: false + required: + - metadata title: RAGQueryResult QueryChunksRequest: type: object diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index d83506dd4..e517d9c3c 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -165,6 +165,7 @@ class ToolResponse(BaseModel): call_id: str tool_name: Union[BuiltinTool, str] content: InterleavedContent + metadata: Optional[Dict[str, Any]] = None @field_validator("tool_name", mode="before") @classmethod diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index cff8eeefe..2b9ef10d8 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -26,6 +26,7 @@ class RAGDocument(BaseModel): @json_schema_type class RAGQueryResult(BaseModel): content: Optional[InterleavedContent] = None + metadata: Dict[str, Any] = Field(default_factory=dict) @json_schema_type diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index b83be127f..a4d84edbe 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -72,6 +72,7 @@ class ToolInvocationResult(BaseModel): content: InterleavedContent error_message: Optional[str] = None error_code: Optional[int] = None + metadata: Optional[Dict[str, Any]] = None class ToolStore(Protocol): diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index edd253356..560215b25 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -62,7 +62,7 @@ from llama_stack.apis.inference import ( UserMessage, ) from llama_stack.apis.safety import Safety -from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolRuntime +from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime from llama_stack.apis.vector_io import VectorIO from llama_stack.models.llama.datatypes import ( BuiltinTool, @@ -587,6 +587,7 @@ class ChatAgent(ShieldRunnerMixin): call_id="", tool_name=MEMORY_QUERY_TOOL, content=retrieved_context or [], + metadata=result.metadata, ) ], ), @@ -795,13 +796,21 @@ class ChatAgent(ShieldRunnerMixin): }, ) as span: tool_execution_start_time = datetime.now() - result_messages = await execute_tool_call_maybe( + tool_call = message.tool_calls[0] + tool_result = await execute_tool_call_maybe( self.tool_runtime_api, session_id, - [message], + tool_call, toolgroup_args, tool_to_group, ) + result_messages = [ + ToolResponseMessage( + call_id=tool_call.call_id, + tool_name=tool_call.tool_name, + content=tool_result.content, + ) + ] assert len(result_messages) == 1, "Currently not supporting multiple messages" result_message = result_messages[0] span.set_attribute("output", result_message.model_dump_json()) @@ -820,6 +829,7 @@ class ChatAgent(ShieldRunnerMixin): call_id=result_message.call_id, tool_name=result_message.tool_name, content=result_message.content, + metadata=tool_result.metadata, ) ], started_at=tool_execution_start_time, @@ -1058,19 +1068,10 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa async def execute_tool_call_maybe( tool_runtime_api: ToolRuntime, session_id: str, - messages: List[CompletionMessage], + tool_call: ToolCall, toolgroup_args: Dict[str, Dict[str, Any]], tool_to_group: Dict[str, str], -) -> List[ToolResponseMessage]: - # While Tools.run interface takes a list of messages, - # All tools currently only run on a single message - # When this changes, we can drop this assert - # Whether to call tools on each message and aggregate - # or aggregate and call tool once, reamins to be seen. - assert len(messages) == 1, "Expected single message" - message = messages[0] - - tool_call = message.tool_calls[0] +) -> ToolInvocationResult: name = tool_call.tool_name group_name = tool_to_group.get(name, None) if group_name is None: @@ -1091,14 +1092,7 @@ async def execute_tool_call_maybe( **tool_call_args, ), ) - - return [ - ToolResponseMessage( - call_id=tool_call.call_id, - tool_name=tool_call.tool_name, - content=result.content, - ) - ] + return result def _interpret_content_as_attachment( diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index a6cd57923..306bd78a6 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -119,10 +119,10 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): # sort by score chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False) - + chunks = chunks[: query_config.max_chunks] tokens = 0 picked = [] - for c in chunks[: query_config.max_chunks]: + for c in chunks: metadata = c.metadata tokens += metadata["token_count"] if tokens > query_config.max_tokens_in_context: @@ -146,6 +146,9 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): text="\n=== END-RETRIEVED-CONTEXT ===\n", ), ], + metadata={ + "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]], + }, ) async def list_runtime_tools( diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 781095d2b..23ae601e4 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -457,6 +457,7 @@ def test_rag_agent(llama_stack_client, agent_config): vector_db_id=vector_db_id, embedding_model="all-MiniLM-L6-v2", embedding_dimension=384, + provider_id="faiss", ) llama_stack_client.tool_runtime.rag_tool.insert( documents=documents, @@ -492,11 +493,13 @@ def test_rag_agent(llama_stack_client, agent_config): response = rag_agent.create_turn( messages=[{"role": "user", "content": prompt}], session_id=session_id, + stream=False, ) - logs = [str(log) for log in EventLogger().log(response) if log is not None] - logs_str = "".join(logs) - assert "Tool:query_from_memory" in logs_str - assert expected_kw in logs_str.lower() + # rag is called + assert response.steps[0].tool_calls[0].tool_name == "query_from_memory" + # document ids are present in metadata + assert "num-0" in response.steps[0].tool_responses[0].metadata["document_ids"] + assert expected_kw in response.output_message.content def test_rag_and_code_agent(llama_stack_client, agent_config): From 9bbe34694dc450a59692f6aa5e33b7020e57b199 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Fri, 21 Feb 2025 22:15:40 +0100 Subject: [PATCH 26/45] ci: add mypy for static type checking (#1101) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? - Enable mypy to run in the CI on a subset of the repository - Fix a few mypy errors - Run mypy from pre-commit Signed-off-by: Sébastien Han [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation) Signed-off-by: Sébastien Han --- .pre-commit-config.yaml | 27 ++-- llama_stack/apis/common/type_system.py | 15 ++- llama_stack/schema_utils.py | 11 +- llama_stack/scripts/distro_codegen.py | 6 +- llama_stack/scripts/run_client_sdk_tests.py | 2 +- pyproject.toml | 23 ++++ requirements.txt | 2 +- uv.lock | 130 ++++++++++---------- 8 files changed, 125 insertions(+), 91 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 56e35aa6e..85cb1b91a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -45,23 +45,26 @@ repos: hooks: - id: uv-export args: [ - "--frozen", - "--no-hashes", - "--no-emit-project", + "--frozen", + "--no-hashes", + "--no-emit-project", "--output-file=requirements.txt" ] files: ^pyproject\.toml$ - id: uv-sync -# - repo: https://github.com/pre-commit/mirrors-mypy -# rev: v1.14.0 -# hooks: -# - id: mypy -# additional_dependencies: -# - types-requests -# - types-setuptools -# - pydantic -# args: [--ignore-missing-imports] +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.15.0 + hooks: + - id: mypy + additional_dependencies: + - uv==0.6.2 + - mypy + - pytest + - rich + - types-requests + - pydantic + pass_filenames: false # - repo: https://github.com/jsh9/pydoclint # rev: d88180a8632bb1602a4d81344085cf320f288c5a diff --git a/llama_stack/apis/common/type_system.py b/llama_stack/apis/common/type_system.py index 139ae8875..d7746df8d 100644 --- a/llama_stack/apis/common/type_system.py +++ b/llama_stack/apis/common/type_system.py @@ -91,15 +91,18 @@ ParamType = register_schema( name="ParamType", ) +""" # TODO: recursive definition of ParamType in these containers # will cause infinite recursion in OpenAPI generation script # since we are going with ChatCompletionInputType and CompletionInputType # we don't need to worry about ArrayType/ObjectType/UnionType for now -# ArrayType.model_rebuild() -# ObjectType.model_rebuild() -# UnionType.model_rebuild() +ArrayType.model_rebuild() +ObjectType.model_rebuild() +UnionType.model_rebuild() -# class CustomType(BaseModel): -# type: Literal["custom"] = "custom" -# validator_class: str +class CustomType(BaseModel): +pylint: disable=syntax-error + type: Literal["custom"] = "custom" + validator_class: str +""" diff --git a/llama_stack/schema_utils.py b/llama_stack/schema_utils.py index 581404844..ad92338e6 100644 --- a/llama_stack/schema_utils.py +++ b/llama_stack/schema_utils.py @@ -5,12 +5,10 @@ # the root directory of this source tree. from dataclasses import dataclass -from typing import Any, Callable, List, Optional, TypeVar +from typing import Any, Callable, List, Optional, Protocol, TypeVar from .strong_typing.schema import json_schema_type, register_schema # noqa: F401 -T = TypeVar("T") - @dataclass class WebMethod: @@ -22,6 +20,13 @@ class WebMethod: raw_bytes_request_body: Optional[bool] = False +class HasWebMethod(Protocol): + __webmethod__: WebMethod + + +T = TypeVar("T", bound=HasWebMethod) # Bound T to classes that match this protocol + + def webmethod( route: Optional[str] = None, method: Optional[str] = None, diff --git a/llama_stack/scripts/distro_codegen.py b/llama_stack/scripts/distro_codegen.py index 1c44b4625..76c7283eb 100644 --- a/llama_stack/scripts/distro_codegen.py +++ b/llama_stack/scripts/distro_codegen.py @@ -11,7 +11,7 @@ import subprocess import sys from functools import partial from pathlib import Path -from typing import Iterator +from typing import Iterable from rich.progress import Progress, SpinnerColumn, TextColumn @@ -39,7 +39,7 @@ class ChangedPathTracker: return self._changed_paths -def find_template_dirs(templates_dir: Path) -> Iterator[Path]: +def find_template_dirs(templates_dir: Path) -> Iterable[Path]: """Find immediate subdirectories in the templates folder.""" if not templates_dir.exists(): raise FileNotFoundError(f"Templates directory not found: {templates_dir}") @@ -90,7 +90,7 @@ def check_for_changes(change_tracker: ChangedPathTracker) -> bool: return has_changes -def collect_template_dependencies(template_dir: Path) -> tuple[str, list[str]]: +def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[str]]: try: module_name = f"llama_stack.templates.{template_dir.name}" module = importlib.import_module(module_name) diff --git a/llama_stack/scripts/run_client_sdk_tests.py b/llama_stack/scripts/run_client_sdk_tests.py index 1e2ef1ac8..6aaeb3273 100644 --- a/llama_stack/scripts/run_client_sdk_tests.py +++ b/llama_stack/scripts/run_client_sdk_tests.py @@ -52,7 +52,7 @@ def main(parser: argparse.ArgumentParser): pytest_args, "-s", "-v", - REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH, + str(REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH), ] ) diff --git a/pyproject.toml b/pyproject.toml index c8ed5737b..2bad04163 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -158,3 +158,26 @@ ignore = [ "B007", "B008", ] + +[tool.mypy] +mypy_path = ["llama_stack"] +packages = ["llama_stack"] +disable_error_code = [] +warn_return_any = true +# # honor excludes by not following there through imports +follow_imports = "silent" +exclude = [ + # As we fix more and more of these, we should remove them from the list + "llama_stack/providers", + "llama_stack/distribution", + "llama_stack/apis", + "llama_stack/cli", + "llama_stack/models", + "llama_stack/strong_typing", + "llama_stack/templates", +] + +[[tool.mypy.overrides]] +# packages that lack typing annotations, do not have stubs, or are unavailable. +module = ["llama_models.*", "yaml", "fire"] +ignore_missing_imports = true diff --git a/requirements.txt b/requirements.txt index 02e1a8655..014db083a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ fsspec==2025.2.0 h11==0.14.0 httpcore==1.0.7 httpx==0.28.1 -huggingface-hub==0.28.1 +huggingface-hub==0.29.0 idna==3.10 jinja2==3.1.5 jsonschema==4.23.0 diff --git a/uv.lock b/uv.lock index ce633c174..3cf05f17d 100644 --- a/uv.lock +++ b/uv.lock @@ -584,7 +584,7 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "0.28.1" +version = "0.29.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, @@ -595,9 +595,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e7/ce/a734204aaae6c35a22f9956ebcd8d8708ae5b842e15d6f42bd6f49e634a4/huggingface_hub-0.28.1.tar.gz", hash = "sha256:893471090c98e3b6efbdfdacafe4052b20b84d59866fb6f54c33d9af18c303ae", size = 387074 } +sdist = { url = "https://files.pythonhosted.org/packages/e2/ac/9f7010c8b050d80b64bfddcc09ef4a4450ae4369940d1b01fa13f5d083de/huggingface_hub-0.29.0.tar.gz", hash = "sha256:64034c852be270cac16c5743fe1f659b14515a9de6342d6f42cbb2ede191fc80", size = 389753 } wheels = [ - { url = "https://files.pythonhosted.org/packages/ea/da/6c2bea5327b640920267d3bf2c9fc114cfbd0a5de234d81cda80cc9e33c8/huggingface_hub-0.28.1-py3-none-any.whl", hash = "sha256:aa6b9a3ffdae939b72c464dbb0d7f99f56e649b55c3d52406f49e0a5a620c0a7", size = 464068 }, + { url = "https://files.pythonhosted.org/packages/2a/4d/8092df2cb0cafa9fcaf691db851b2fccfe9cad4048e081436bbbdf56e4e1/huggingface_hub-0.29.0-py3-none-any.whl", hash = "sha256:c02daa0b6bafbdacb1320fdfd1dc7151d0940825c88c4ef89837fdb1f6ea0afe", size = 468012 }, ] [[package]] @@ -994,7 +994,7 @@ wheels = [ [[package]] name = "lm-format-enforcer" -version = "0.10.9" +version = "0.10.10" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "interegular" }, @@ -1002,9 +1002,9 @@ dependencies = [ { name = "pydantic" }, { name = "pyyaml" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/73/5d/401ffb7a8895e0f3206345e96c52b428c81e4a2af049d426023cb9cb0cdb/lm_format_enforcer-0.10.9.tar.gz", hash = "sha256:3e0bfeaf9fac9f69c8947da554db9a19a76d0be6e85075055f2c70d0aca420da", size = 39713 } +sdist = { url = "https://files.pythonhosted.org/packages/9d/3f/1ec9e91208a2b8af28ef2caf096e70446d7b3c7218c891fffa899608bf08/lm_format_enforcer-0.10.10.tar.gz", hash = "sha256:b1ff9530ccf73097e35bded94737677c9768a235d74b26af8cd25414efdf85f5", size = 39393 } wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/01/e78fdf09de2b4e7750a402eaa4f6783c7215ededd4bc6fe4a3f6d69c49da/lm_format_enforcer-0.10.9-py3-none-any.whl", hash = "sha256:6f3602d3470f54b3ba10d356ea34cc136afbd13394a360949dd8d943a2f2471e", size = 43940 }, + { url = "https://files.pythonhosted.org/packages/32/55/9b91312b7b59903ffa2d1c4310cbeecfea0f8e8e12b154d7ad1d093d0b03/lm_format_enforcer-0.10.10-py3-none-any.whl", hash = "sha256:c5e4330c717780b046c77f46699f8a668cb2b806da540c0127da942538d13695", size = 44231 }, ] [[package]] @@ -1362,7 +1362,7 @@ wheels = [ [[package]] name = "openai" -version = "1.63.0" +version = "1.63.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -1374,9 +1374,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/4f/32/2049e973a646801df425aecdf88c6504ca878bdb3951fe12076fc30f2977/openai-1.63.0.tar.gz", hash = "sha256:597d7a1b35b113e5a09fcb953bdb1eef44f404a39985f3d7573b3ab09221fd66", size = 356710 } +sdist = { url = "https://files.pythonhosted.org/packages/e6/1c/11b520deb71f9ea54ced3c52cd6a5f7131215deba63ad07f23982e328141/openai-1.63.2.tar.gz", hash = "sha256:aeabeec984a7d2957b4928ceaa339e2ead19c61cfcf35ae62b7c363368d26360", size = 356902 } wheels = [ - { url = "https://files.pythonhosted.org/packages/67/a0/e1fe4e87218639fc0a0927da5266c2978eaa0e2eb5437479ee64a11535bb/openai-1.63.0-py3-none-any.whl", hash = "sha256:a664dfc78f0a05ca46c3e21f344f840cf6bf7174f13cfa9de214ed28bfca1dda", size = 472282 }, + { url = "https://files.pythonhosted.org/packages/15/64/db3462b358072387b8e93e6e6a38d3c741a17b4a84171ef01d6c85c63f25/openai-1.63.2-py3-none-any.whl", hash = "sha256:1f38b27b5a40814c2b7d8759ec78110df58c4a614c25f182809ca52b080ff4d4", size = 472282 }, ] [[package]] @@ -2577,14 +2577,14 @@ wheels = [ [[package]] name = "sphinxcontrib-video" -version = "0.4.0" +version = "0.4.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "sphinx" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c7/58/b41664ea7522e886fb33c85a562fe05fc44e1e53bc59da7466d4d7b65787/sphinxcontrib_video-0.4.0.tar.gz", hash = "sha256:1052553faf5f0e255e5e292fae3f5f2fdd295f8a80745d649bfcdbcb12581a69", size = 11324 } +sdist = { url = "https://files.pythonhosted.org/packages/16/48/063e167b6e692bc84bbad74df30bcb27e460a7c620af7824729db8dba606/sphinxcontrib_video-0.4.1.tar.gz", hash = "sha256:75a033e71b7de124cc5902430b7ba818a1c6c377be6401d07e9f2329a95d5ca4", size = 11362 } wheels = [ - { url = "https://files.pythonhosted.org/packages/b3/d5/fa5544847af0e9d335dfa6ece10860abf61b8305365fbb2afe4e9f396b04/sphinxcontrib_video-0.4.0-py3-none-any.whl", hash = "sha256:b94212a6a3489f399ab8287db01536cdd018b5410bbf78d0685db96777ce44e8", size = 10045 }, + { url = "https://files.pythonhosted.org/packages/5d/8b/a0271fe65357860ccc52168181891e9fc9d354bfdc9be273e6a77b84f905/sphinxcontrib_video-0.4.1-py3-none-any.whl", hash = "sha256:d63ec68983dac36960557973281a616b5d9e68838369763313fc80533b1ad774", size = 10066 }, ] [[package]] @@ -2950,61 +2950,61 @@ wheels = [ [[package]] name = "websockets" -version = "14.2" +version = "15.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/94/54/8359678c726243d19fae38ca14a334e740782336c9f19700858c4eb64a1e/websockets-14.2.tar.gz", hash = "sha256:5059ed9c54945efb321f097084b4c7e52c246f2c869815876a69d1efc4ad6eb5", size = 164394 } +sdist = { url = "https://files.pythonhosted.org/packages/2e/7a/8bc4d15af7ff30f7ba34f9a172063bfcee9f5001d7cef04bee800a658f33/websockets-15.0.tar.gz", hash = "sha256:ca36151289a15b39d8d683fd8b7abbe26fc50be311066c5f8dcf3cb8cee107ab", size = 175574 } wheels = [ - { url = "https://files.pythonhosted.org/packages/28/fa/76607eb7dcec27b2d18d63f60a32e60e2b8629780f343bb83a4dbb9f4350/websockets-14.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:e8179f95323b9ab1c11723e5d91a89403903f7b001828161b480a7810b334885", size = 163089 }, - { url = "https://files.pythonhosted.org/packages/9e/00/ad2246b5030575b79e7af0721810fdaecaf94c4b2625842ef7a756fa06dd/websockets-14.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0d8c3e2cdb38f31d8bd7d9d28908005f6fa9def3324edb9bf336d7e4266fd397", size = 160741 }, - { url = "https://files.pythonhosted.org/packages/72/f7/60f10924d333a28a1ff3fcdec85acf226281331bdabe9ad74947e1b7fc0a/websockets-14.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:714a9b682deb4339d39ffa674f7b674230227d981a37d5d174a4a83e3978a610", size = 160996 }, - { url = "https://files.pythonhosted.org/packages/63/7c/c655789cf78648c01ac6ecbe2d6c18f91b75bdc263ffee4d08ce628d12f0/websockets-14.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f2e53c72052f2596fb792a7acd9704cbc549bf70fcde8a99e899311455974ca3", size = 169974 }, - { url = "https://files.pythonhosted.org/packages/fb/5b/013ed8b4611857ac92ac631079c08d9715b388bd1d88ec62e245f87a39df/websockets-14.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e3fbd68850c837e57373d95c8fe352203a512b6e49eaae4c2f4088ef8cf21980", size = 168985 }, - { url = "https://files.pythonhosted.org/packages/cd/33/aa3e32fd0df213a5a442310754fe3f89dd87a0b8e5b4e11e0991dd3bcc50/websockets-14.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b27ece32f63150c268593d5fdb82819584831a83a3f5809b7521df0685cd5d8", size = 169297 }, - { url = "https://files.pythonhosted.org/packages/93/17/dae0174883d6399f57853ac44abf5f228eaba86d98d160f390ffabc19b6e/websockets-14.2-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:4daa0faea5424d8713142b33825fff03c736f781690d90652d2c8b053345b0e7", size = 169677 }, - { url = "https://files.pythonhosted.org/packages/42/e2/0375af7ac00169b98647c804651c515054b34977b6c1354f1458e4116c1e/websockets-14.2-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:bc63cee8596a6ec84d9753fd0fcfa0452ee12f317afe4beae6b157f0070c6c7f", size = 169089 }, - { url = "https://files.pythonhosted.org/packages/73/8d/80f71d2a351a44b602859af65261d3dde3a0ce4e76cf9383738a949e0cc3/websockets-14.2-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:7a570862c325af2111343cc9b0257b7119b904823c675b22d4ac547163088d0d", size = 169026 }, - { url = "https://files.pythonhosted.org/packages/48/97/173b1fa6052223e52bb4054a141433ad74931d94c575e04b654200b98ca4/websockets-14.2-cp310-cp310-win32.whl", hash = "sha256:75862126b3d2d505e895893e3deac0a9339ce750bd27b4ba515f008b5acf832d", size = 163967 }, - { url = "https://files.pythonhosted.org/packages/c0/5b/2fcf60f38252a4562b28b66077e0d2b48f91fef645d5f78874cd1dec807b/websockets-14.2-cp310-cp310-win_amd64.whl", hash = "sha256:cc45afb9c9b2dc0852d5c8b5321759cf825f82a31bfaf506b65bf4668c96f8b2", size = 164413 }, - { url = "https://files.pythonhosted.org/packages/15/b6/504695fb9a33df0ca56d157f5985660b5fc5b4bf8c78f121578d2d653392/websockets-14.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3bdc8c692c866ce5fefcaf07d2b55c91d6922ac397e031ef9b774e5b9ea42166", size = 163088 }, - { url = "https://files.pythonhosted.org/packages/81/26/ebfb8f6abe963c795122439c6433c4ae1e061aaedfc7eff32d09394afbae/websockets-14.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c93215fac5dadc63e51bcc6dceca72e72267c11def401d6668622b47675b097f", size = 160745 }, - { url = "https://files.pythonhosted.org/packages/a1/c6/1435ad6f6dcbff80bb95e8986704c3174da8866ddb751184046f5c139ef6/websockets-14.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1c9b6535c0e2cf8a6bf938064fb754aaceb1e6a4a51a80d884cd5db569886910", size = 160995 }, - { url = "https://files.pythonhosted.org/packages/96/63/900c27cfe8be1a1f2433fc77cd46771cf26ba57e6bdc7cf9e63644a61863/websockets-14.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0a52a6d7cf6938e04e9dceb949d35fbdf58ac14deea26e685ab6368e73744e4c", size = 170543 }, - { url = "https://files.pythonhosted.org/packages/00/8b/bec2bdba92af0762d42d4410593c1d7d28e9bfd952c97a3729df603dc6ea/websockets-14.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9f05702e93203a6ff5226e21d9b40c037761b2cfb637187c9802c10f58e40473", size = 169546 }, - { url = "https://files.pythonhosted.org/packages/6b/a9/37531cb5b994f12a57dec3da2200ef7aadffef82d888a4c29a0d781568e4/websockets-14.2-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22441c81a6748a53bfcb98951d58d1af0661ab47a536af08920d129b4d1c3473", size = 169911 }, - { url = "https://files.pythonhosted.org/packages/60/d5/a6eadba2ed9f7e65d677fec539ab14a9b83de2b484ab5fe15d3d6d208c28/websockets-14.2-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:efd9b868d78b194790e6236d9cbc46d68aba4b75b22497eb4ab64fa640c3af56", size = 170183 }, - { url = "https://files.pythonhosted.org/packages/76/57/a338ccb00d1df881c1d1ee1f2a20c9c1b5b29b51e9e0191ee515d254fea6/websockets-14.2-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1a5a20d5843886d34ff8c57424cc65a1deda4375729cbca4cb6b3353f3ce4142", size = 169623 }, - { url = "https://files.pythonhosted.org/packages/64/22/e5f7c33db0cb2c1d03b79fd60d189a1da044e2661f5fd01d629451e1db89/websockets-14.2-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:34277a29f5303d54ec6468fb525d99c99938607bc96b8d72d675dee2b9f5bf1d", size = 169583 }, - { url = "https://files.pythonhosted.org/packages/aa/2e/2b4662237060063a22e5fc40d46300a07142afe30302b634b4eebd717c07/websockets-14.2-cp311-cp311-win32.whl", hash = "sha256:02687db35dbc7d25fd541a602b5f8e451a238ffa033030b172ff86a93cb5dc2a", size = 163969 }, - { url = "https://files.pythonhosted.org/packages/94/a5/0cda64e1851e73fc1ecdae6f42487babb06e55cb2f0dc8904b81d8ef6857/websockets-14.2-cp311-cp311-win_amd64.whl", hash = "sha256:862e9967b46c07d4dcd2532e9e8e3c2825e004ffbf91a5ef9dde519ee2effb0b", size = 164408 }, - { url = "https://files.pythonhosted.org/packages/c1/81/04f7a397653dc8bec94ddc071f34833e8b99b13ef1a3804c149d59f92c18/websockets-14.2-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:1f20522e624d7ffbdbe259c6b6a65d73c895045f76a93719aa10cd93b3de100c", size = 163096 }, - { url = "https://files.pythonhosted.org/packages/ec/c5/de30e88557e4d70988ed4d2eabd73fd3e1e52456b9f3a4e9564d86353b6d/websockets-14.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:647b573f7d3ada919fd60e64d533409a79dcf1ea21daeb4542d1d996519ca967", size = 160758 }, - { url = "https://files.pythonhosted.org/packages/e5/8c/d130d668781f2c77d106c007b6c6c1d9db68239107c41ba109f09e6c218a/websockets-14.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6af99a38e49f66be5a64b1e890208ad026cda49355661549c507152113049990", size = 160995 }, - { url = "https://files.pythonhosted.org/packages/a6/bc/f6678a0ff17246df4f06765e22fc9d98d1b11a258cc50c5968b33d6742a1/websockets-14.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:091ab63dfc8cea748cc22c1db2814eadb77ccbf82829bac6b2fbe3401d548eda", size = 170815 }, - { url = "https://files.pythonhosted.org/packages/d8/b2/8070cb970c2e4122a6ef38bc5b203415fd46460e025652e1ee3f2f43a9a3/websockets-14.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b374e8953ad477d17e4851cdc66d83fdc2db88d9e73abf755c94510ebddceb95", size = 169759 }, - { url = "https://files.pythonhosted.org/packages/81/da/72f7caabd94652e6eb7e92ed2d3da818626e70b4f2b15a854ef60bf501ec/websockets-14.2-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a39d7eceeea35db85b85e1169011bb4321c32e673920ae9c1b6e0978590012a3", size = 170178 }, - { url = "https://files.pythonhosted.org/packages/31/e0/812725b6deca8afd3a08a2e81b3c4c120c17f68c9b84522a520b816cda58/websockets-14.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0a6f3efd47ffd0d12080594f434faf1cd2549b31e54870b8470b28cc1d3817d9", size = 170453 }, - { url = "https://files.pythonhosted.org/packages/66/d3/8275dbc231e5ba9bb0c4f93144394b4194402a7a0c8ffaca5307a58ab5e3/websockets-14.2-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:065ce275e7c4ffb42cb738dd6b20726ac26ac9ad0a2a48e33ca632351a737267", size = 169830 }, - { url = "https://files.pythonhosted.org/packages/a3/ae/e7d1a56755ae15ad5a94e80dd490ad09e345365199600b2629b18ee37bc7/websockets-14.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e9d0e53530ba7b8b5e389c02282f9d2aa47581514bd6049d3a7cffe1385cf5fe", size = 169824 }, - { url = "https://files.pythonhosted.org/packages/b6/32/88ccdd63cb261e77b882e706108d072e4f1c839ed723bf91a3e1f216bf60/websockets-14.2-cp312-cp312-win32.whl", hash = "sha256:20e6dd0984d7ca3037afcb4494e48c74ffb51e8013cac71cf607fffe11df7205", size = 163981 }, - { url = "https://files.pythonhosted.org/packages/b3/7d/32cdb77990b3bdc34a306e0a0f73a1275221e9a66d869f6ff833c95b56ef/websockets-14.2-cp312-cp312-win_amd64.whl", hash = "sha256:44bba1a956c2c9d268bdcdf234d5e5ff4c9b6dc3e300545cbe99af59dda9dcce", size = 164421 }, - { url = "https://files.pythonhosted.org/packages/82/94/4f9b55099a4603ac53c2912e1f043d6c49d23e94dd82a9ce1eb554a90215/websockets-14.2-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6f1372e511c7409a542291bce92d6c83320e02c9cf392223272287ce55bc224e", size = 163102 }, - { url = "https://files.pythonhosted.org/packages/8e/b7/7484905215627909d9a79ae07070057afe477433fdacb59bf608ce86365a/websockets-14.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4da98b72009836179bb596a92297b1a61bb5a830c0e483a7d0766d45070a08ad", size = 160766 }, - { url = "https://files.pythonhosted.org/packages/a3/a4/edb62efc84adb61883c7d2c6ad65181cb087c64252138e12d655989eec05/websockets-14.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f8a86a269759026d2bde227652b87be79f8a734e582debf64c9d302faa1e9f03", size = 160998 }, - { url = "https://files.pythonhosted.org/packages/f5/79/036d320dc894b96af14eac2529967a6fc8b74f03b83c487e7a0e9043d842/websockets-14.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:86cf1aaeca909bf6815ea714d5c5736c8d6dd3a13770e885aafe062ecbd04f1f", size = 170780 }, - { url = "https://files.pythonhosted.org/packages/63/75/5737d21ee4dd7e4b9d487ee044af24a935e36a9ff1e1419d684feedcba71/websockets-14.2-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9b0f6c3ba3b1240f602ebb3971d45b02cc12bd1845466dd783496b3b05783a5", size = 169717 }, - { url = "https://files.pythonhosted.org/packages/2c/3c/bf9b2c396ed86a0b4a92ff4cdaee09753d3ee389be738e92b9bbd0330b64/websockets-14.2-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:669c3e101c246aa85bc8534e495952e2ca208bd87994650b90a23d745902db9a", size = 170155 }, - { url = "https://files.pythonhosted.org/packages/75/2d/83a5aca7247a655b1da5eb0ee73413abd5c3a57fc8b92915805e6033359d/websockets-14.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:eabdb28b972f3729348e632ab08f2a7b616c7e53d5414c12108c29972e655b20", size = 170495 }, - { url = "https://files.pythonhosted.org/packages/79/dd/699238a92761e2f943885e091486378813ac8f43e3c84990bc394c2be93e/websockets-14.2-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2066dc4cbcc19f32c12a5a0e8cc1b7ac734e5b64ac0a325ff8353451c4b15ef2", size = 169880 }, - { url = "https://files.pythonhosted.org/packages/c8/c9/67a8f08923cf55ce61aadda72089e3ed4353a95a3a4bc8bf42082810e580/websockets-14.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:ab95d357cd471df61873dadf66dd05dd4709cae001dd6342edafc8dc6382f307", size = 169856 }, - { url = "https://files.pythonhosted.org/packages/17/b1/1ffdb2680c64e9c3921d99db460546194c40d4acbef999a18c37aa4d58a3/websockets-14.2-cp313-cp313-win32.whl", hash = "sha256:a9e72fb63e5f3feacdcf5b4ff53199ec8c18d66e325c34ee4c551ca748623bbc", size = 163974 }, - { url = "https://files.pythonhosted.org/packages/14/13/8b7fc4cb551b9cfd9890f0fd66e53c18a06240319915533b033a56a3d520/websockets-14.2-cp313-cp313-win_amd64.whl", hash = "sha256:b439ea828c4ba99bb3176dc8d9b933392a2413c0f6b149fdcba48393f573377f", size = 164420 }, - { url = "https://files.pythonhosted.org/packages/10/3d/91d3d2bb1325cd83e8e2c02d0262c7d4426dc8fa0831ef1aa4d6bf2041af/websockets-14.2-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:d7d9cafbccba46e768be8a8ad4635fa3eae1ffac4c6e7cb4eb276ba41297ed29", size = 160773 }, - { url = "https://files.pythonhosted.org/packages/33/7c/cdedadfef7381939577858b1b5718a4ab073adbb584e429dd9d9dc9bfe16/websockets-14.2-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:c76193c1c044bd1e9b3316dcc34b174bbf9664598791e6fb606d8d29000e070c", size = 161007 }, - { url = "https://files.pythonhosted.org/packages/ca/35/7a20a3c450b27c04e50fbbfc3dfb161ed8e827b2a26ae31c4b59b018b8c6/websockets-14.2-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd475a974d5352390baf865309fe37dec6831aafc3014ffac1eea99e84e83fc2", size = 162264 }, - { url = "https://files.pythonhosted.org/packages/e8/9c/e3f9600564b0c813f2448375cf28b47dc42c514344faed3a05d71fb527f9/websockets-14.2-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2c6c0097a41968b2e2b54ed3424739aab0b762ca92af2379f152c1aef0187e1c", size = 161873 }, - { url = "https://files.pythonhosted.org/packages/3f/37/260f189b16b2b8290d6ae80c9f96d8b34692cf1bb3475df54c38d3deb57d/websockets-14.2-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6d7ff794c8b36bc402f2e07c0b2ceb4a2424147ed4785ff03e2a7af03711d60a", size = 161818 }, - { url = "https://files.pythonhosted.org/packages/ff/1e/e47dedac8bf7140e59aa6a679e850c4df9610ae844d71b6015263ddea37b/websockets-14.2-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:dec254fcabc7bd488dab64846f588fc5b6fe0d78f641180030f8ea27b76d72c3", size = 164465 }, - { url = "https://files.pythonhosted.org/packages/7b/c8/d529f8a32ce40d98309f4470780631e971a5a842b60aec864833b3615786/websockets-14.2-py3-none-any.whl", hash = "sha256:7a6ceec4ea84469f15cf15807a747e9efe57e369c384fa86e022b3bea679b79b", size = 157416 }, + { url = "https://files.pythonhosted.org/packages/3d/f1/b20cc4c1ff84911c791f36fa511a78203836bb4d603f56290de08c067437/websockets-15.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5e6ee18a53dd5743e6155b8ff7e8e477c25b29b440f87f65be8165275c87fef0", size = 174701 }, + { url = "https://files.pythonhosted.org/packages/f9/e8/4de59ee85ec86052ca574f4e5327ef948e4f77757d3c9c1503f5a0e9c039/websockets-15.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ee06405ea2e67366a661ed313e14cf2a86e84142a3462852eb96348f7219cee3", size = 172358 }, + { url = "https://files.pythonhosted.org/packages/2f/ea/b0f95815cdc83d61b1a895858671c6af38a76c23f3ea5d91e2ba11bbedc7/websockets-15.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:8711682a629bbcaf492f5e0af72d378e976ea1d127a2d47584fa1c2c080b436b", size = 172610 }, + { url = "https://files.pythonhosted.org/packages/09/ed/c5d8f1f296f475c00611a40eff6a952248785efb125f91a0b29575f36ba6/websockets-15.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94c4a9b01eede952442c088d415861b0cf2053cbd696b863f6d5022d4e4e2453", size = 181579 }, + { url = "https://files.pythonhosted.org/packages/b7/fc/2444b5ae792d92179f20cec53475bcc25d1d7f00a2be9947de9837ef230a/websockets-15.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:45535fead66e873f411c1d3cf0d3e175e66f4dd83c4f59d707d5b3e4c56541c4", size = 180588 }, + { url = "https://files.pythonhosted.org/packages/ff/b5/0945a31562d351cff26d76a2ae9a4ba4536e698aa059a4262afd793b2a1d/websockets-15.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e389efe46ccb25a1f93d08c7a74e8123a2517f7b7458f043bd7529d1a63ffeb", size = 180902 }, + { url = "https://files.pythonhosted.org/packages/b6/7c/e9d844b87754bc83b294cc1c695cbc6c5d42e329b85d2bf2d7bb9554d09c/websockets-15.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:67a04754d121ea5ca39ddedc3f77071651fb5b0bc6b973c71c515415b44ed9c5", size = 181282 }, + { url = "https://files.pythonhosted.org/packages/9e/6c/6a5d3272f494fa2fb4806b896ecb312bd6c72bab632df4ace19946c079dc/websockets-15.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:bd66b4865c8b853b8cca7379afb692fc7f52cf898786537dfb5e5e2d64f0a47f", size = 180694 }, + { url = "https://files.pythonhosted.org/packages/b2/32/1fb4b62c2ec2c9844d4ddaa4021d993552c7c493a0acdcec95551679d501/websockets-15.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:a4cc73a6ae0a6751b76e69cece9d0311f054da9b22df6a12f2c53111735657c8", size = 180631 }, + { url = "https://files.pythonhosted.org/packages/e4/9b/5ef1ddb8857ce894217bdd9572ad98c1cef20d8f9f0f43823b782b7ded6b/websockets-15.0-cp310-cp310-win32.whl", hash = "sha256:89da58e4005e153b03fe8b8794330e3f6a9774ee9e1c3bd5bc52eb098c3b0c4f", size = 175664 }, + { url = "https://files.pythonhosted.org/packages/29/63/c320572ccf813ed2bc3058a0e0291ee95eb258dc5e6b3446ca45dc1af0fd/websockets-15.0-cp310-cp310-win_amd64.whl", hash = "sha256:4ff380aabd7a74a42a760ee76c68826a8f417ceb6ea415bd574a035a111fd133", size = 176109 }, + { url = "https://files.pythonhosted.org/packages/ee/16/81a7403c8c0a33383de647e89c07824ea6a654e3877d6ff402cbae298cb8/websockets-15.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:dd24c4d256558429aeeb8d6c24ebad4e982ac52c50bc3670ae8646c181263965", size = 174702 }, + { url = "https://files.pythonhosted.org/packages/ef/40/4629202386a3bf1195db9fe41baeb1d6dfd8d72e651d9592d81dae7fdc7c/websockets-15.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f83eca8cbfd168e424dfa3b3b5c955d6c281e8fc09feb9d870886ff8d03683c7", size = 172359 }, + { url = "https://files.pythonhosted.org/packages/7b/33/dfb650e822bc7912d8c542c452497867af91dec81e7b5bf96aca5b419d58/websockets-15.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4095a1f2093002c2208becf6f9a178b336b7572512ee0a1179731acb7788e8ad", size = 172604 }, + { url = "https://files.pythonhosted.org/packages/2e/52/666743114513fcffd43ee5df261a1eb5d41f8e9861b7a190b730732c19ba/websockets-15.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb915101dfbf318486364ce85662bb7b020840f68138014972c08331458d41f3", size = 182145 }, + { url = "https://files.pythonhosted.org/packages/9c/63/5273f146b13aa4a057a95ab0855d9990f3a1ced63693f4365135d1abfacc/websockets-15.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:45d464622314973d78f364689d5dbb9144e559f93dca11b11af3f2480b5034e1", size = 181152 }, + { url = "https://files.pythonhosted.org/packages/0f/ae/075697f3f97de7c26b73ae96d952e13fa36393e0db3f028540b28954e0a9/websockets-15.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ace960769d60037ca9625b4c578a6f28a14301bd2a1ff13bb00e824ac9f73e55", size = 181523 }, + { url = "https://files.pythonhosted.org/packages/25/87/06d091bbcbe01903bed3dad3bb4a1a3c516f61e611ec31fffb28abe4974b/websockets-15.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c7cd4b1015d2f60dfe539ee6c95bc968d5d5fad92ab01bb5501a77393da4f596", size = 181791 }, + { url = "https://files.pythonhosted.org/packages/77/08/5063b6cc1b2aa1fba2ee3b578b777db22fde7145f121d07fd878811e983b/websockets-15.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:4f7290295794b5dec470867c7baa4a14182b9732603fd0caf2a5bf1dc3ccabf3", size = 181231 }, + { url = "https://files.pythonhosted.org/packages/86/ff/af23084df0a7405bb2add12add8c17d6192a8de9480f1b90d12352ba2b7d/websockets-15.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:3abd670ca7ce230d5a624fd3d55e055215d8d9b723adee0a348352f5d8d12ff4", size = 181191 }, + { url = "https://files.pythonhosted.org/packages/21/ce/b2bdfcf49201dee0b899edc6a814755763ec03d74f2714923d38453a9e8d/websockets-15.0-cp311-cp311-win32.whl", hash = "sha256:110a847085246ab8d4d119632145224d6b49e406c64f1bbeed45c6f05097b680", size = 175666 }, + { url = "https://files.pythonhosted.org/packages/8d/7b/444edcd5365538c226b631897975a65bbf5ccf27c77102e17d8f12a306ea/websockets-15.0-cp311-cp311-win_amd64.whl", hash = "sha256:8d7bbbe2cd6ed80aceef2a14e9f1c1b61683194c216472ed5ff33b700e784e37", size = 176105 }, + { url = "https://files.pythonhosted.org/packages/22/1e/92c4547d7b2a93f848aedaf37e9054111bc00dc11bff4385ca3f80dbb412/websockets-15.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:cccc18077acd34c8072578394ec79563664b1c205f7a86a62e94fafc7b59001f", size = 174709 }, + { url = "https://files.pythonhosted.org/packages/9f/37/eae4830a28061ba552516d84478686b637cd9e57d6a90b45ad69e89cb0af/websockets-15.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d4c22992e24f12de340ca5f824121a5b3e1a37ad4360b4e1aaf15e9d1c42582d", size = 172372 }, + { url = "https://files.pythonhosted.org/packages/46/2f/b409f8b8aa9328d5a47f7a301a43319d540d70cf036d1e6443675978a988/websockets-15.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1206432cc6c644f6fc03374b264c5ff805d980311563202ed7fef91a38906276", size = 172607 }, + { url = "https://files.pythonhosted.org/packages/d6/81/d7e2e4542d4b4df849b0110df1b1f94f2647b71ab4b65d672090931ad2bb/websockets-15.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d3cc75ef3e17490042c47e0523aee1bcc4eacd2482796107fd59dd1100a44bc", size = 182422 }, + { url = "https://files.pythonhosted.org/packages/b6/91/3b303160938d123eea97f58be363f7dbec76e8c59d587e07b5bc257dd584/websockets-15.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b89504227a5311610e4be16071465885a0a3d6b0e82e305ef46d9b064ce5fb72", size = 181362 }, + { url = "https://files.pythonhosted.org/packages/f2/8b/df6807f1ca339c567aba9a7ab03bfdb9a833f625e8d2b4fc7529e4c701de/websockets-15.0-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56e3efe356416bc67a8e093607315951d76910f03d2b3ad49c4ade9207bf710d", size = 181787 }, + { url = "https://files.pythonhosted.org/packages/21/37/e6d3d5ebb0ebcaf98ae84904205c9dcaf3e0fe93e65000b9f08631ed7309/websockets-15.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0f2205cdb444a42a7919690238fb5979a05439b9dbb73dd47c863d39640d85ab", size = 182058 }, + { url = "https://files.pythonhosted.org/packages/c9/df/6aca296f2be4c638ad20908bb3d7c94ce7afc8d9b4b2b0780d1fc59b359c/websockets-15.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:aea01f40995fa0945c020228ab919b8dfc93fc8a9f2d3d705ab5b793f32d9e99", size = 181434 }, + { url = "https://files.pythonhosted.org/packages/88/f1/75717a982bab39bbe63c83f9df0e7753e5c98bab907eb4fb5d97fe5c8c11/websockets-15.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a9f8e33747b1332db11cf7fcf4a9512bef9748cb5eb4d3f7fbc8c30d75dc6ffc", size = 181431 }, + { url = "https://files.pythonhosted.org/packages/e7/15/cee9e63ed9ac5bfc1a3ae8fc6c02c41745023c21eed622eef142d8fdd749/websockets-15.0-cp312-cp312-win32.whl", hash = "sha256:32e02a2d83f4954aa8c17e03fe8ec6962432c39aca4be7e8ee346b05a3476904", size = 175678 }, + { url = "https://files.pythonhosted.org/packages/4e/00/993974c60f40faabb725d4dbae8b072ef73b4c4454bd261d3b1d34ace41f/websockets-15.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffc02b159b65c05f2ed9ec176b715b66918a674bd4daed48a9a7a590dd4be1aa", size = 176119 }, + { url = "https://files.pythonhosted.org/packages/12/23/be28dc1023707ac51768f848d28a946443041a348ee3a54abdf9f6283372/websockets-15.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:d2244d8ab24374bed366f9ff206e2619345f9cd7fe79aad5225f53faac28b6b1", size = 174714 }, + { url = "https://files.pythonhosted.org/packages/8f/ff/02b5e9fbb078e7666bf3d25c18c69b499747a12f3e7f2776063ef3fb7061/websockets-15.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:3a302241fbe825a3e4fe07666a2ab513edfdc6d43ce24b79691b45115273b5e7", size = 172374 }, + { url = "https://files.pythonhosted.org/packages/8e/61/901c8d4698e0477eff4c3c664d53f898b601fa83af4ce81946650ec2a4cb/websockets-15.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:10552fed076757a70ba2c18edcbc601c7637b30cdfe8c24b65171e824c7d6081", size = 172605 }, + { url = "https://files.pythonhosted.org/packages/d2/4b/dc47601a80dff317aecf8da7b4ab278d11d3494b2c373b493e4887561f90/websockets-15.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c53f97032b87a406044a1c33d1e9290cc38b117a8062e8a8b285175d7e2f99c9", size = 182380 }, + { url = "https://files.pythonhosted.org/packages/83/f7/b155d2b38f05ed47a0b8de1c9ea245fcd7fc625d89f35a37eccba34b42de/websockets-15.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1caf951110ca757b8ad9c4974f5cac7b8413004d2f29707e4d03a65d54cedf2b", size = 181325 }, + { url = "https://files.pythonhosted.org/packages/d3/ff/040a20c01c294695cac0e361caf86f33347acc38f164f6d2be1d3e007d9f/websockets-15.0-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8bf1ab71f9f23b0a1d52ec1682a3907e0c208c12fef9c3e99d2b80166b17905f", size = 181763 }, + { url = "https://files.pythonhosted.org/packages/cb/6a/af23e93678fda8341ac8775e85123425e45c608389d3514863c702896ea5/websockets-15.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:bfcd3acc1a81f106abac6afd42327d2cf1e77ec905ae11dc1d9142a006a496b6", size = 182097 }, + { url = "https://files.pythonhosted.org/packages/7e/3e/1069e159c30129dc03c01513b5830237e576f47cedb888777dd885cae583/websockets-15.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:c8c5c8e1bac05ef3c23722e591ef4f688f528235e2480f157a9cfe0a19081375", size = 181485 }, + { url = "https://files.pythonhosted.org/packages/9a/a7/c91c47103f1cd941b576bbc452601e9e01f67d5c9be3e0a9abe726491ab5/websockets-15.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:86bfb52a9cfbcc09aba2b71388b0a20ea5c52b6517c0b2e316222435a8cdab72", size = 181466 }, + { url = "https://files.pythonhosted.org/packages/16/32/a4ca6e3d56c24aac46b0cf5c03b841379f6409d07fc2044b244f90f54105/websockets-15.0-cp313-cp313-win32.whl", hash = "sha256:26ba70fed190708551c19a360f9d7eca8e8c0f615d19a574292b7229e0ae324c", size = 175673 }, + { url = "https://files.pythonhosted.org/packages/c0/31/25a417a23e985b61ffa5544f9facfe4a118cb64d664c886f1244a8baeca5/websockets-15.0-cp313-cp313-win_amd64.whl", hash = "sha256:ae721bcc8e69846af00b7a77a220614d9b2ec57d25017a6bbde3a99473e41ce8", size = 176115 }, + { url = "https://files.pythonhosted.org/packages/42/52/359467c7ca12721a04520da9ba9fc29da2cd176c30992f6f81fa881bb3e5/websockets-15.0-pp310-pypy310_pp73-macosx_10_15_x86_64.whl", hash = "sha256:b499caef4bca9cbd0bd23cd3386f5113ee7378094a3cb613a2fa543260fe9506", size = 172384 }, + { url = "https://files.pythonhosted.org/packages/7c/ff/36fd8a45fac404d8f109e03ca06328f49847d71c0c048414c76bb2db91c4/websockets-15.0-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:17f2854c6bd9ee008c4b270f7010fe2da6c16eac5724a175e75010aacd905b31", size = 172616 }, + { url = "https://files.pythonhosted.org/packages/b1/a8/65496a87984815e2837835d5ac3c9f81ea82031036877e8f80953c59dbd9/websockets-15.0-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:89f72524033abbfde880ad338fd3c2c16e31ae232323ebdfbc745cbb1b3dcc03", size = 173871 }, + { url = "https://files.pythonhosted.org/packages/23/89/9441e1e0818d46fe22d78b3e5c8fe2316516211330e138231c90dce5559e/websockets-15.0-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1657a9eecb29d7838e3b415458cc494e6d1b194f7ac73a34aa55c6fb6c72d1f3", size = 173477 }, + { url = "https://files.pythonhosted.org/packages/2f/1b/80460b3ac9795ef7bbaa074c603d64e009dbb2ceb11008416efab0dcc811/websockets-15.0-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e413352a921f5ad5d66f9e2869b977e88d5103fc528b6deb8423028a2befd842", size = 173425 }, + { url = "https://files.pythonhosted.org/packages/56/d1/8da7e733ed266f342e8c544c3b8338449de9b860d85d9a0bfd4fe1857d6e/websockets-15.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:8561c48b0090993e3b2a54db480cab1d23eb2c5735067213bb90f402806339f5", size = 176160 }, + { url = "https://files.pythonhosted.org/packages/e8/b2/31eec524b53f01cd8343f10a8e429730c52c1849941d1f530f8253b6d934/websockets-15.0-py3-none-any.whl", hash = "sha256:51ffd53c53c4442415b613497a34ba0aa7b99ac07f1e4a62db5dcd640ae6c3c3", size = 169023 }, ] [[package]] From ab54b8cd582dcaf7d67e063168f0c08ef3f18c0b Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 21 Feb 2025 13:21:28 -0800 Subject: [PATCH 27/45] feat(providers): support non-llama models for inference providers (#1200) This PR begins the process of supporting non-llama models within Llama Stack. We start simple by adding support for this functionality within a few existing providers: fireworks, together and ollama. ## Test Plan ```bash LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/inference/test_text_inference.py \ --inference-model accounts/fireworks/models/phi-3-vision-128k-instruct ``` ^ this passes most of the tests but as expected fails the tool calling related tests since they are very specific to Llama models ``` inference/test_text_inference.py::test_text_completion_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED inference/test_text_inference.py::test_completion_log_probs_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED inference/test_text_inference.py::test_completion_log_probs_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED inference/test_text_inference.py::test_text_completion_structured_output[accounts/fireworks/models/phi-3-vision-128k-instruct-completion-01] PASSED inference/test_text_inference.py::test_text_chat_completion_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-Which planet do humans live on?-Earth] PASSED inference/test_text_inference.py::test_text_chat_completion_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-Which planet has rings around it with a name starting w ith letter S?-Saturn] PASSED inference/test_text_inference.py::test_text_chat_completion_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-What's the name of the Sun in latin?-Sol] PASSED inference/test_text_inference.py::test_text_chat_completion_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct-What is the name of the US captial?-Washington] PASSED inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] FAILED inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[accounts/fireworks/models/phi-3-vision-128k-instruct] FAILED inference/test_text_inference.py::test_text_chat_completion_with_tool_choice_required[accounts/fireworks/models/phi-3-vision-128k-instruct] FAILED inference/test_text_inference.py::test_text_chat_completion_with_tool_choice_none[accounts/fireworks/models/phi-3-vision-128k-instruct] PASSED inference/test_text_inference.py::test_text_chat_completion_structured_output[accounts/fireworks/models/phi-3-vision-128k-instruct] ERROR inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[accounts/fireworks/models/phi-3-vision-128k-instruct-True] PASSED inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[accounts/fireworks/models/phi-3-vision-128k-instruct-False] PASSED ``` --- .../remote/inference/fireworks/fireworks.py | 7 +-- .../remote/inference/ollama/ollama.py | 5 +- .../remote/inference/together/together.py | 7 +-- .../utils/inference/model_registry.py | 18 +++--- tests/client-sdk/conftest.py | 59 ++++++++++++++++-- .../inference/test_text_inference.py | 61 ++++++++----------- .../inference/test_vision_inference.py | 20 ++---- 7 files changed, 103 insertions(+), 74 deletions(-) diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index b9b23584b..90fe70cbf 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -209,15 +209,14 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv input_dict = {} media_present = request_has_media(request) + llama_model = self.get_llama_model(request.model) if isinstance(request, ChatCompletionRequest): - if media_present: + if media_present or not llama_model: input_dict["messages"] = [ await convert_message_to_openai_dict(m, download=True) for m in request.messages ] else: - input_dict["prompt"] = await chat_completion_request_to_prompt( - request, self.get_llama_model(request.model) - ) + input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model) else: assert not media_present, "Fireworks does not support media for Completion requests" input_dict["prompt"] = await completion_request_to_prompt(request) diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 058bbeeee..6fcfd2e99 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -178,8 +178,9 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): input_dict = {} media_present = request_has_media(request) + llama_model = self.register_helper.get_llama_model(request.model) if isinstance(request, ChatCompletionRequest): - if media_present: + if media_present or not llama_model: contents = [await convert_message_to_openai_dict_for_ollama(m) for m in request.messages] # flatten the list of lists input_dict["messages"] = [item for sublist in contents for item in sublist] @@ -187,7 +188,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): input_dict["raw"] = True input_dict["prompt"] = await chat_completion_request_to_prompt( request, - self.register_helper.get_llama_model(request.model), + llama_model, ) else: assert not media_present, "Ollama does not support media for Completion requests" diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 1fca54bb3..040f04e77 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -203,13 +203,12 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict: input_dict = {} media_present = request_has_media(request) + llama_model = self.get_llama_model(request.model) if isinstance(request, ChatCompletionRequest): - if media_present: + if media_present or not llama_model: input_dict["messages"] = [await convert_message_to_openai_dict(m) for m in request.messages] else: - input_dict["prompt"] = await chat_completion_request_to_prompt( - request, self.get_llama_model(request.model) - ) + input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model) else: assert not media_present, "Together does not support media for Completion requests" input_dict["prompt"] = await completion_request_to_prompt(request) diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 0882019e3..d9e24662a 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -79,28 +79,28 @@ class ModelRegistryHelper(ModelsProtocolPrivate): provider_resource_id = model.provider_resource_id else: provider_resource_id = self.get_provider_model_id(model.provider_resource_id) + if provider_resource_id: model.provider_resource_id = provider_resource_id else: - if model.metadata.get("llama_model") is None: - raise ValueError( - f"Model '{model.provider_resource_id}' is not available and no llama_model was specified in metadata. " - "Please specify a llama_model in metadata or use a supported model identifier" - ) + llama_model = model.metadata.get("llama_model") + if llama_model is None: + return model + existing_llama_model = self.get_llama_model(model.provider_resource_id) if existing_llama_model: - if existing_llama_model != model.metadata["llama_model"]: + if existing_llama_model != llama_model: raise ValueError( f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'" ) else: - if model.metadata["llama_model"] not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR: + if llama_model not in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR: raise ValueError( - f"Invalid llama_model '{model.metadata['llama_model']}' specified in metadata. " + f"Invalid llama_model '{llama_model}' specified in metadata. " f"Must be one of: {', '.join(ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR.keys())}" ) self.provider_id_to_llama_model_map[model.provider_resource_id] = ( - ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[model.metadata["llama_model"]] + ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model] ) return model diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py index efdec6b01..662505590 100644 --- a/tests/client-sdk/conftest.py +++ b/tests/client-sdk/conftest.py @@ -42,28 +42,30 @@ def pytest_addoption(parser): ) parser.addoption( "--inference-model", - action="store", default=TEXT_MODEL, help="Specify the inference model to use for testing", ) parser.addoption( "--vision-inference-model", - action="store", default=VISION_MODEL, help="Specify the vision inference model to use for testing", ) parser.addoption( "--safety-shield", - action="store", default="meta-llama/Llama-Guard-3-1B", help="Specify the safety shield model to use for testing", ) parser.addoption( "--embedding-model", - action="store", - default=TEXT_MODEL, + default=None, help="Specify the embedding model to use for testing", ) + parser.addoption( + "--embedding-dimension", + type=int, + default=384, + help="Output dimensionality of the embedding model to use for testing", + ) @pytest.fixture(scope="session") @@ -78,7 +80,7 @@ def provider_data(): @pytest.fixture(scope="session") -def llama_stack_client(provider_data): +def llama_stack_client(provider_data, text_model_id): if os.environ.get("LLAMA_STACK_CONFIG"): client = LlamaStackAsLibraryClient( get_env_or_fail("LLAMA_STACK_CONFIG"), @@ -95,6 +97,45 @@ def llama_stack_client(provider_data): ) else: raise ValueError("LLAMA_STACK_CONFIG or LLAMA_STACK_BASE_URL must be set") + + return client + + +@pytest.fixture(scope="session") +def inference_provider_type(llama_stack_client): + providers = llama_stack_client.providers.list() + inference_providers = [p for p in providers if p.api == "inference"] + assert len(inference_providers) > 0, "No inference providers found" + return inference_providers[0].provider_type + + +@pytest.fixture(scope="session") +def client_with_models(llama_stack_client, text_model_id, vision_model_id, embedding_model_id, embedding_dimension): + client = llama_stack_client + + providers = [p for p in client.providers.list() if p.api == "inference"] + assert len(providers) > 0, "No inference providers found" + inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"] + if text_model_id: + client.models.register(model_id=text_model_id, provider_id=inference_providers[0]) + if vision_model_id: + client.models.register(model_id=vision_model_id, provider_id=inference_providers[0]) + + if embedding_model_id and embedding_dimension: + # try to find a provider that supports embeddings, if sentence-transformers is not available + selected_provider = None + for p in providers: + if p.provider_type == "inline::sentence-transformers": + selected_provider = p + break + + selected_provider = selected_provider or providers[0] + client.models.register( + model_id=embedding_model_id, + provider_id=selected_provider.provider_id, + model_type="embedding", + metadata={"embedding_dimension": embedding_dimension}, + ) return client @@ -117,3 +158,9 @@ def pytest_generate_tests(metafunc): [metafunc.config.getoption("--embedding-model")], scope="session", ) + if "embedding_dimension" in metafunc.fixturenames: + metafunc.parametrize( + "embedding_dimension", + [metafunc.config.getoption("--embedding-dimension")], + scope="session", + ) diff --git a/tests/client-sdk/inference/test_text_inference.py b/tests/client-sdk/inference/test_text_inference.py index 545325bbe..75d932380 100644 --- a/tests/client-sdk/inference/test_text_inference.py +++ b/tests/client-sdk/inference/test_text_inference.py @@ -28,14 +28,6 @@ def provider_tool_format(inference_provider_type): ) -@pytest.fixture(scope="session") -def inference_provider_type(llama_stack_client): - providers = llama_stack_client.providers.list() - inference_providers = [p for p in providers if p.api == "inference"] - assert len(inference_providers) > 0, "No inference providers found" - return inference_providers[0].provider_type - - @pytest.fixture def get_weather_tool_definition(): return { @@ -50,8 +42,8 @@ def get_weather_tool_definition(): } -def test_text_completion_non_streaming(llama_stack_client, text_model_id): - response = llama_stack_client.inference.completion( +def test_text_completion_non_streaming(client_with_models, text_model_id): + response = client_with_models.inference.completion( content="Complete the sentence using one word: Roses are red, violets are ", stream=False, model_id=text_model_id, @@ -63,8 +55,8 @@ def test_text_completion_non_streaming(llama_stack_client, text_model_id): # assert "blue" in response.content.lower().strip() -def test_text_completion_streaming(llama_stack_client, text_model_id): - response = llama_stack_client.inference.completion( +def test_text_completion_streaming(client_with_models, text_model_id): + response = client_with_models.inference.completion( content="Complete the sentence using one word: Roses are red, violets are ", stream=True, model_id=text_model_id, @@ -78,11 +70,11 @@ def test_text_completion_streaming(llama_stack_client, text_model_id): assert len(content_str) > 10 -def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id, inference_provider_type): +def test_completion_log_probs_non_streaming(client_with_models, text_model_id, inference_provider_type): if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K: pytest.xfail(f"{inference_provider_type} doesn't support log probs yet") - response = llama_stack_client.inference.completion( + response = client_with_models.inference.completion( content="Complete the sentence: Micheael Jordan is born in ", stream=False, model_id=text_model_id, @@ -98,11 +90,11 @@ def test_completion_log_probs_non_streaming(llama_stack_client, text_model_id, i assert all(len(logprob.logprobs_by_token) == 1 for logprob in response.logprobs) -def test_completion_log_probs_streaming(llama_stack_client, text_model_id, inference_provider_type): +def test_completion_log_probs_streaming(client_with_models, text_model_id, inference_provider_type): if inference_provider_type not in PROVIDER_LOGPROBS_TOP_K: pytest.xfail(f"{inference_provider_type} doesn't support log probs yet") - response = llama_stack_client.inference.completion( + response = client_with_models.inference.completion( content="Complete the sentence: Micheael Jordan is born in ", stream=True, model_id=text_model_id, @@ -123,7 +115,7 @@ def test_completion_log_probs_streaming(llama_stack_client, text_model_id, infer @pytest.mark.parametrize("test_case", ["completion-01"]) -def test_text_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type, test_case): +def test_text_completion_structured_output(client_with_models, text_model_id, test_case): class AnswerFormat(BaseModel): name: str year_born: str @@ -132,7 +124,7 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in tc = TestCase(test_case) user_input = tc["user_input"] - response = llama_stack_client.inference.completion( + response = client_with_models.inference.completion( model_id=text_model_id, content=user_input, stream=False, @@ -161,8 +153,8 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in ), ], ) -def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, question, expected): - response = llama_stack_client.inference.chat_completion( +def test_text_chat_completion_non_streaming(client_with_models, text_model_id, question, expected): + response = client_with_models.inference.chat_completion( model_id=text_model_id, messages=[ { @@ -184,8 +176,8 @@ def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, q ("What is the name of the US captial?", "Washington"), ], ) -def test_text_chat_completion_streaming(llama_stack_client, text_model_id, question, expected): - response = llama_stack_client.inference.chat_completion( +def test_text_chat_completion_streaming(client_with_models, text_model_id, question, expected): + response = client_with_models.inference.chat_completion( model_id=text_model_id, messages=[{"role": "user", "content": question}], stream=True, @@ -196,9 +188,9 @@ def test_text_chat_completion_streaming(llama_stack_client, text_model_id, quest def test_text_chat_completion_with_tool_calling_and_non_streaming( - llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format + client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format ): - response = llama_stack_client.inference.chat_completion( + response = client_with_models.inference.chat_completion( model_id=text_model_id, messages=[ {"role": "system", "content": "You are a helpful assistant."}, @@ -233,9 +225,9 @@ def extract_tool_invocation_content(response): def test_text_chat_completion_with_tool_calling_and_streaming( - llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format + client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format ): - response = llama_stack_client.inference.chat_completion( + response = client_with_models.inference.chat_completion( model_id=text_model_id, messages=[ {"role": "system", "content": "You are a helpful assistant."}, @@ -251,13 +243,12 @@ def test_text_chat_completion_with_tool_calling_and_streaming( def test_text_chat_completion_with_tool_choice_required( - llama_stack_client, + client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format, - inference_provider_type, ): - response = llama_stack_client.inference.chat_completion( + response = client_with_models.inference.chat_completion( model_id=text_model_id, messages=[ {"role": "system", "content": "You are a helpful assistant."}, @@ -275,9 +266,9 @@ def test_text_chat_completion_with_tool_choice_required( def test_text_chat_completion_with_tool_choice_none( - llama_stack_client, text_model_id, get_weather_tool_definition, provider_tool_format + client_with_models, text_model_id, get_weather_tool_definition, provider_tool_format ): - response = llama_stack_client.inference.chat_completion( + response = client_with_models.inference.chat_completion( model_id=text_model_id, messages=[ {"role": "system", "content": "You are a helpful assistant."}, @@ -292,7 +283,7 @@ def test_text_chat_completion_with_tool_choice_none( @pytest.mark.parametrize("test_case", ["chat_completion-01"]) -def test_text_chat_completion_structured_output(llama_stack_client, text_model_id, inference_provider_type, test_case): +def test_text_chat_completion_structured_output(client_with_models, text_model_id, test_case): class AnswerFormat(BaseModel): first_name: str last_name: str @@ -301,7 +292,7 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i tc = TestCase(test_case) - response = llama_stack_client.inference.chat_completion( + response = client_with_models.inference.chat_completion( model_id=text_model_id, messages=tc["messages"], response_format={ @@ -325,7 +316,7 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i False, ], ) -def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_client, text_model_id, streaming): +def test_text_chat_completion_tool_calling_tools_not_in_request(client_with_models, text_model_id, streaming): # TODO: more dynamic lookup on tool_prompt_format for model family tool_prompt_format = "json" if "3.1" in text_model_id else "python_list" request = { @@ -381,7 +372,7 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_clie "stream": streaming, } - response = llama_stack_client.inference.chat_completion(**request) + response = client_with_models.inference.chat_completion(**request) if streaming: for chunk in response: diff --git a/tests/client-sdk/inference/test_vision_inference.py b/tests/client-sdk/inference/test_vision_inference.py index b23089747..8fa0d8023 100644 --- a/tests/client-sdk/inference/test_vision_inference.py +++ b/tests/client-sdk/inference/test_vision_inference.py @@ -10,14 +10,6 @@ import pathlib import pytest -@pytest.fixture(scope="session") -def inference_provider_type(llama_stack_client): - providers = llama_stack_client.providers.list() - inference_providers = [p for p in providers if p.api == "inference"] - assert len(inference_providers) > 0, "No inference providers found" - return inference_providers[0].provider_type - - @pytest.fixture def image_path(): return pathlib.Path(__file__).parent / "dog.png" @@ -35,7 +27,7 @@ def base64_image_url(base64_image_data, image_path): return f"data:image/{image_path.suffix[1:]};base64,{base64_image_data}" -def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id): +def test_image_chat_completion_non_streaming(client_with_models, vision_model_id): message = { "role": "user", "content": [ @@ -53,7 +45,7 @@ def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id }, ], } - response = llama_stack_client.inference.chat_completion( + response = client_with_models.inference.chat_completion( model_id=vision_model_id, messages=[message], stream=False, @@ -63,7 +55,7 @@ def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id assert any(expected in message_content for expected in {"dog", "puppy", "pup"}) -def test_image_chat_completion_streaming(llama_stack_client, vision_model_id): +def test_image_chat_completion_streaming(client_with_models, vision_model_id): message = { "role": "user", "content": [ @@ -81,7 +73,7 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id): }, ], } - response = llama_stack_client.inference.chat_completion( + response = client_with_models.inference.chat_completion( model_id=vision_model_id, messages=[message], stream=True, @@ -94,7 +86,7 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id): @pytest.mark.parametrize("type_", ["url", "data"]) -def test_image_chat_completion_base64(llama_stack_client, vision_model_id, base64_image_data, base64_image_url, type_): +def test_image_chat_completion_base64(client_with_models, vision_model_id, base64_image_data, base64_image_url, type_): image_spec = { "url": { "type": "image", @@ -122,7 +114,7 @@ def test_image_chat_completion_base64(llama_stack_client, vision_model_id, base6 }, ], } - response = llama_stack_client.inference.chat_completion( + response = client_with_models.inference.chat_completion( model_id=vision_model_id, messages=[message], stream=False, From 182608d4bf19aa155fb5b29987874fa71579ccc3 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 21 Feb 2025 14:24:09 -0800 Subject: [PATCH 28/45] better test naming --- tests/client-sdk/conftest.py | 61 ++++++++++++++++++++++++------------ 1 file changed, 41 insertions(+), 20 deletions(-) diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py index 662505590..13dee0ba3 100644 --- a/tests/client-sdk/conftest.py +++ b/tests/client-sdk/conftest.py @@ -139,28 +139,49 @@ def client_with_models(llama_stack_client, text_model_id, vision_model_id, embed return client +MODEL_SHORT_IDS = { + "meta-llama/Llama-3.1-8B-Instruct": "8B", + "meta-llama/Llama-3.2-11B-Vision-Instruct": "11B", + "all-MiniLM-L6-v2": "MiniLM", +} + + +def get_short_id(value): + return MODEL_SHORT_IDS.get(value, value) + + def pytest_generate_tests(metafunc): + params = [] + values = [] + id_parts = [] + if "text_model_id" in metafunc.fixturenames: - metafunc.parametrize( - "text_model_id", - [metafunc.config.getoption("--inference-model")], - scope="session", - ) + params.append("text_model_id") + val = metafunc.config.getoption("--inference-model") + values.append(val) + id_parts.append(f"txt={get_short_id(val)}") + if "vision_model_id" in metafunc.fixturenames: - metafunc.parametrize( - "vision_model_id", - [metafunc.config.getoption("--vision-inference-model")], - scope="session", - ) + params.append("vision_model_id") + val = metafunc.config.getoption("--vision-inference-model") + values.append(val) + id_parts.append(f"vis={get_short_id(val)}") + if "embedding_model_id" in metafunc.fixturenames: - metafunc.parametrize( - "embedding_model_id", - [metafunc.config.getoption("--embedding-model")], - scope="session", - ) + params.append("embedding_model_id") + val = metafunc.config.getoption("--embedding-model") + values.append(val) + if val is not None: + id_parts.append(f"emb={get_short_id(val)}") + if "embedding_dimension" in metafunc.fixturenames: - metafunc.parametrize( - "embedding_dimension", - [metafunc.config.getoption("--embedding-dimension")], - scope="session", - ) + params.append("embedding_dimension") + val = metafunc.config.getoption("--embedding-dimension") + values.append(val) + if val != 384: + id_parts.append(f"dim={val}") + + if params: + # Create a single test ID string + test_id = ":".join(id_parts) + metafunc.parametrize(params, [values], scope="session", ids=[test_id]) From e7d261ef4ad9c0a672611a66b6bdaf52aacbeac4 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 21 Feb 2025 15:10:10 -0800 Subject: [PATCH 29/45] Fix test infra, sentence embeddings mixin --- llama_stack/distribution/library_client.py | 11 +++++----- .../utils/inference/embedding_mixin.py | 3 ++- tests/client-sdk/vector_io/conftest.py | 22 ------------------- tests/client-sdk/vector_io/test_vector_io.py | 10 ++++----- 4 files changed, 12 insertions(+), 34 deletions(-) delete mode 100644 tests/client-sdk/vector_io/conftest.py diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 639e5ee73..5790c498b 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -230,12 +230,11 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): if Api.telemetry in self.impls: setup_logger(self.impls[Api.telemetry]) - console = Console() - console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:") - - # Redact sensitive information before printing - safe_config = redact_sensitive_fields(self.config.model_dump()) - console.print(yaml.dump(safe_config, indent=2)) + if not os.environ.get("PYTEST_CURRENT_TEST"): + console = Console() + console.print(f"Using config [blue]{self.config_path_or_template_name}[/blue]:") + safe_config = redact_sensitive_fields(self.config.model_dump()) + console.print(yaml.dump(safe_config, indent=2)) endpoints = get_all_api_endpoints() endpoint_impls = {} diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index 32aa5da3f..ac421475f 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -14,6 +14,7 @@ from llama_stack.apis.inference import ( ModelStore, TextTruncation, ) +from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str EMBEDDING_MODELS = {} @@ -34,7 +35,7 @@ class SentenceTransformerEmbeddingMixin: ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) embedding_model = self._load_sentence_transformer_model(model.provider_resource_id) - embeddings = embedding_model.encode(contents) + embeddings = embedding_model.encode([interleaved_content_as_str(content) for content in contents]) return EmbeddingsResponse(embeddings=embeddings) def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer": diff --git a/tests/client-sdk/vector_io/conftest.py b/tests/client-sdk/vector_io/conftest.py deleted file mode 100644 index 64cac27d2..000000000 --- a/tests/client-sdk/vector_io/conftest.py +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - - -def pytest_addoption(parser): - parser.addoption( - "--embedding-model", - action="store", - default="all-MiniLM-L6-v2", - help="Specify the embedding model to use for testing", - ) - - -def pytest_generate_tests(metafunc): - if "embedding_model" in metafunc.fixturenames: - metafunc.parametrize( - "embedding_model", - [metafunc.config.getoption("--embedding-model")], - ) diff --git a/tests/client-sdk/vector_io/test_vector_io.py b/tests/client-sdk/vector_io/test_vector_io.py index c7e4040b6..e093548b5 100644 --- a/tests/client-sdk/vector_io/test_vector_io.py +++ b/tests/client-sdk/vector_io/test_vector_io.py @@ -36,12 +36,12 @@ def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry @pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS) -def test_vector_db_retrieve(llama_stack_client, embedding_model, empty_vector_db_registry, provider_id): +def test_vector_db_retrieve(llama_stack_client, embedding_model_id, empty_vector_db_registry, provider_id): # Register a memory bank first vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}" llama_stack_client.vector_dbs.register( vector_db_id=vector_db_id, - embedding_model=embedding_model, + embedding_model=embedding_model_id, embedding_dimension=384, provider_id=provider_id, ) @@ -50,7 +50,7 @@ def test_vector_db_retrieve(llama_stack_client, embedding_model, empty_vector_db response = llama_stack_client.vector_dbs.retrieve(vector_db_id=vector_db_id) assert response is not None assert response.identifier == vector_db_id - assert response.embedding_model == embedding_model + assert response.embedding_model == embedding_model_id assert response.provider_id == provider_id assert response.provider_resource_id == vector_db_id @@ -61,11 +61,11 @@ def test_vector_db_list(llama_stack_client, empty_vector_db_registry): @pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS) -def test_vector_db_register(llama_stack_client, embedding_model, empty_vector_db_registry, provider_id): +def test_vector_db_register(llama_stack_client, embedding_model_id, empty_vector_db_registry, provider_id): vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}" llama_stack_client.vector_dbs.register( vector_db_id=vector_db_id, - embedding_model=embedding_model, + embedding_model=embedding_model_id, embedding_dimension=384, provider_id=provider_id, ) From bf38d0aba0e2a526c91591268bc2ed4d4b3f90b3 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Fri, 21 Feb 2025 15:24:28 -0800 Subject: [PATCH 30/45] test: fix test_rag_agent test (#1215) Summary: Test Plan: LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py::test_rag_agent --safety-shield meta-llama/Llama-Guard-3-8B --- tests/client-sdk/agents/test_agents.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 23ae601e4..7ede5e517 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -496,10 +496,11 @@ def test_rag_agent(llama_stack_client, agent_config): stream=False, ) # rag is called - assert response.steps[0].tool_calls[0].tool_name == "query_from_memory" + tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution") + assert tool_execution_step.tool_calls[0].tool_name == "query_from_memory" # document ids are present in metadata - assert "num-0" in response.steps[0].tool_responses[0].metadata["document_ids"] - assert expected_kw in response.output_message.content + assert "num-0" in tool_execution_step.tool_responses[0].metadata["document_ids"] + assert expected_kw in response.output_message.content.lower() def test_rag_and_code_agent(llama_stack_client, agent_config): From 45ffe87d7c75c1b9fad6b3074882521cc71367a4 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 21 Feb 2025 15:37:23 -0800 Subject: [PATCH 31/45] Kill noise from test output --- .../providers/inline/agents/meta_reference/agents.py | 9 +-------- llama_stack/providers/utils/inference/embedding_mixin.py | 4 +++- tests/client-sdk/agents/test_agents.py | 2 -- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 8a4d91238..72c1a0f34 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -11,8 +11,6 @@ import tempfile import uuid from typing import AsyncGenerator, List, Optional, Union -from termcolor import colored - from llama_stack.apis.agents import ( AgentConfig, AgentCreateResponse, @@ -69,12 +67,7 @@ class MetaReferenceAgentsImpl(Agents): # check if "bwrap" is available if not shutil.which("bwrap"): - print( - colored( - "Warning: `bwrap` is not available. Code interpreter tool will not work correctly.", - "yellow", - ) - ) + logger.warning("Warning: `bwrap` is not available. Code interpreter tool will not work correctly.") async def create_agent( self, diff --git a/llama_stack/providers/utils/inference/embedding_mixin.py b/llama_stack/providers/utils/inference/embedding_mixin.py index ac421475f..f43475554 100644 --- a/llama_stack/providers/utils/inference/embedding_mixin.py +++ b/llama_stack/providers/utils/inference/embedding_mixin.py @@ -35,7 +35,9 @@ class SentenceTransformerEmbeddingMixin: ) -> EmbeddingsResponse: model = await self.model_store.get_model(model_id) embedding_model = self._load_sentence_transformer_model(model.provider_resource_id) - embeddings = embedding_model.encode([interleaved_content_as_str(content) for content in contents]) + embeddings = embedding_model.encode( + [interleaved_content_as_str(content) for content in contents], show_progress_bar=False + ) return EmbeddingsResponse(embeddings=embeddings) def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer": diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 7ede5e517..c03a2a874 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -90,7 +90,6 @@ class TestClientTool(ClientTool): def agent_config(llama_stack_client, text_model_id): available_shields = [shield.identifier for shield in llama_stack_client.shields.list()] available_shields = available_shields[:1] - print(f"Using shield: {available_shields}") agent_config = AgentConfig( model=text_model_id, instructions="You are a helpful assistant", @@ -489,7 +488,6 @@ def test_rag_agent(llama_stack_client, agent_config): ), ] for prompt, expected_kw in user_prompts: - print(f"User> {prompt}") response = rag_agent.create_turn( messages=[{"role": "user", "content": prompt}], session_id=session_id, From 5be628f637bc0b5f7adfa4950d950e753ba6d67f Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 21 Feb 2025 16:25:51 -0800 Subject: [PATCH 32/45] Add test jsons to MANIFEST for now --- MANIFEST.in | 1 + 1 file changed, 1 insertion(+) diff --git a/MANIFEST.in b/MANIFEST.in index 9d9048983..ec45d8f08 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,3 +3,4 @@ include distributions/dependencies.json include llama_stack/distribution/*.sh include llama_stack/cli/scripts/*.sh include llama_stack/templates/*/*.yaml +include llama_stack/providers/tests/test_cases/*.json From 187524d4aeb7477297e580d2bbace7481109ca75 Mon Sep 17 00:00:00 2001 From: Reid <61492567+reidliu41@users.noreply.github.com> Date: Sat, 22 Feb 2025 08:38:10 +0800 Subject: [PATCH 33/45] feat: add substring search for model list (#1099) # What does this PR do? [Provide a short summary of what this PR does and why. Link to relevant issues if applicable.] `llama model list` or `llama model list --show-all` will list more or all for the models, so add the `search` option to simplify the output. ``` $ llama model list --help usage: llama model list [-h] [--show-all] [-s SEARCH] Show available llama models options: -h, --help show this help message and exit --show-all Show all models (not just defaults) -s SEARCH, --search SEARCH Search for the input string as a substring in the model descriptor(ID) $ llama model list -s 70b +-----------------------+-----------------------------------+----------------+ | Model Descriptor(ID) | Hugging Face Repo | Context Length | +-----------------------+-----------------------------------+----------------+ | Llama3.1-70B | meta-llama/Llama-3.1-70B | 128K | +-----------------------+-----------------------------------+----------------+ | Llama3.1-70B-Instruct | meta-llama/Llama-3.1-70B-Instruct | 128K | +-----------------------+-----------------------------------+----------------+ | Llama3.3-70B-Instruct | meta-llama/Llama-3.3-70B-Instruct | 128K | +-----------------------+-----------------------------------+----------------+ $ llama model list -s 3.1-8b +----------------------+----------------------------------+----------------+ | Model Descriptor(ID) | Hugging Face Repo | Context Length | +----------------------+----------------------------------+----------------+ | Llama3.1-8B | meta-llama/Llama-3.1-8B | 128K | +----------------------+----------------------------------+----------------+ | Llama3.1-8B-Instruct | meta-llama/Llama-3.1-8B-Instruct | 128K | +----------------------+----------------------------------+----------------+ $ llama model list --show-all -s pro +----------------------+-----------------------------+----------------+ | Model Descriptor(ID) | Hugging Face Repo | Context Length | +----------------------+-----------------------------+----------------+ | Prompt-Guard-86M | meta-llama/Prompt-Guard-86M | 2K | +----------------------+-----------------------------+----------------+ $ llama model list -s k Not found for search. ``` [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation) Signed-off-by: reidliu Co-authored-by: reidliu --- llama_stack/cli/model/list.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/llama_stack/cli/model/list.py b/llama_stack/cli/model/list.py index 622a6b4e7..b9499f06d 100644 --- a/llama_stack/cli/model/list.py +++ b/llama_stack/cli/model/list.py @@ -75,6 +75,13 @@ class ModelList(Subcommand): action="store_true", help="List the downloaded models", ) + self.parser.add_argument( + "-s", + "--search", + type=str, + required=False, + help="Search for the input string as a substring in the model descriptor(ID)", + ) def _run_model_list_cmd(self, args: argparse.Namespace) -> None: from .safety_models import prompt_guard_model_sku @@ -94,15 +101,19 @@ class ModelList(Subcommand): continue descriptor = model.descriptor() - rows.append( - [ - descriptor, - model.huggingface_repo, - f"{model.max_seq_length // 1024}K", - ] + if not args.search or args.search.lower() in descriptor.lower(): + rows.append( + [ + descriptor, + model.huggingface_repo, + f"{model.max_seq_length // 1024}K", + ] + ) + if len(rows) == 0: + print(f"Did not find any model matching `{args.search}`.") + else: + print_table( + rows, + headers, + separate_rows=True, ) - print_table( - rows, - headers, - separate_rows=True, - ) From c9e08cc0a8bc02fc1c6a89a7c33751fa13d13a5d Mon Sep 17 00:00:00 2001 From: ehhuang Date: Fri, 21 Feb 2025 16:38:56 -0800 Subject: [PATCH 34/45] test: do not overwrite agent_config (#1216) Summary: Test Plan: --- tests/client-sdk/agents/test_agents.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index c03a2a874..1afec2cb1 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -332,8 +332,11 @@ def test_tool_choice(llama_stack_client, agent_config): ] client_tool = TestClientTool() for tool_choice, expected_tool in data: - agent_config["tool_config"] = {"tool_choice": tool_choice} - agent_config["client_tools"] = [client_tool.get_tool_definition()] + agent_config = { + **agent_config, + "tool_config": {"tool_choice": tool_choice}, + "client_tools": [client_tool.get_tool_definition()], + } agent = Agent(llama_stack_client, agent_config, client_tools=(client_tool,)) session_id = agent.create_session(f"test-session-{uuid4()}") From b890d7a611b3d45b0ffcdab5275f721af9dbfd99 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 21 Feb 2025 16:43:00 -0800 Subject: [PATCH 35/45] Test be not having prints yo --- tests/client-sdk/agents/test_agents.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 1afec2cb1..e5606b50b 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -384,7 +384,6 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config): logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) - print(logs_str) # can't tell a joke: "I don't have a function" assert "function" in logs_str @@ -423,7 +422,6 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config): logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) - print(logs_str) assert "bicycle" in logs_str response = agent.create_turn( @@ -438,7 +436,6 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config): logs = [str(log) for log in EventLogger().log(response) if log is not None] logs_str = "".join(logs) - print(logs_str) assert "-100" in logs_str assert "get_boiling_point" in logs_str @@ -557,7 +554,6 @@ def test_rag_and_code_agent(llama_stack_client, agent_config): ] for prompt, docs, tool_name in user_prompts: - print(f"User> {prompt}") session_id = agent.create_session(f"test-session-{uuid4()}") response = agent.create_turn( messages=[{"role": "user", "content": prompt}], From 19ae4b35d9d22841ca14f30166d4b317554bd28d Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Sat, 22 Feb 2025 12:59:34 -0700 Subject: [PATCH 36/45] docs: Adding Provider sections to docs (#1195) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? Adding Provider sections to docs (some of these will be empty and need updating). This PR is still a draft while I seek feedback from other contributors. I opened it to make the structure visible in the linked GitHub Issue. # Closes https://github.com/meta-llama/llama-stack/issues/1189 - Providers Overview Page ![Screenshot 2025-02-21 at 12 15 09 PM](https://github.com/user-attachments/assets/e83e5a17-0d96-4de0-8251-68161799a054) - SQLite-Vec specific page ![Screenshot 2025-02-21 at 12 15 34 PM](https://github.com/user-attachments/assets/14773900-fc8f-49e9-832a-b060b7ca010a) ## Test Plan N/A [//]: # (## Documentation) --------- Signed-off-by: Francisco Javier Arceo --- docs/source/concepts/index.md | 2 +- docs/source/conf.py | 2 +- docs/source/index.md | 2 + docs/source/providers/index.md | 59 +++++++++++++++++++ docs/source/providers/vector_io/chromadb.md | 36 +++++++++++ docs/source/providers/vector_io/faiss.md | 33 +++++++++++ docs/source/providers/vector_io/pgvector.md | 31 ++++++++++ docs/source/providers/vector_io/qdrant.md | 31 ++++++++++ docs/source/providers/vector_io/sqlite-vec.md | 33 +++++++++++ docs/source/providers/vector_io/weaviate.md | 33 +++++++++++ 10 files changed, 260 insertions(+), 2 deletions(-) create mode 100644 docs/source/providers/index.md create mode 100644 docs/source/providers/vector_io/chromadb.md create mode 100644 docs/source/providers/vector_io/faiss.md create mode 100644 docs/source/providers/vector_io/pgvector.md create mode 100644 docs/source/providers/vector_io/qdrant.md create mode 100644 docs/source/providers/vector_io/sqlite-vec.md create mode 100644 docs/source/providers/vector_io/weaviate.md diff --git a/docs/source/concepts/index.md b/docs/source/concepts/index.md index df46e0134..27eb74f00 100644 --- a/docs/source/concepts/index.md +++ b/docs/source/concepts/index.md @@ -33,7 +33,7 @@ Providers come in two flavors: - **Remote**: the provider runs as a separate service external to the Llama Stack codebase. Llama Stack contains a small amount of adapter code. - **Inline**: the provider is fully specified and implemented within the Llama Stack codebase. It may be a simple wrapper around an existing library, or a full fledged implementation within Llama Stack. -Most importantly, Llama Stack always strives to provide at least one fully "local" provider for each API so you can iterate on a fully featured environment locally. +Most importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally. ## Resources Some of these APIs are associated with a set of **Resources**. Here is the mapping of APIs to resources: diff --git a/docs/source/conf.py b/docs/source/conf.py index a876333db..fd105a6cf 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -15,7 +15,7 @@ from docutils import nodes project = "llama-stack" -copyright = "2024, Meta" +copyright = "2025, Meta" author = "Meta" # -- General configuration --------------------------------------------------- diff --git a/docs/source/index.md b/docs/source/index.md index cb2355bfd..b6fd314b7 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -67,6 +67,7 @@ A number of "adapters" are available for some popular Inference and Vector Store | **Provider** | **Environments** | | :----: | :----: | | FAISS | Single Node | +| SQLite-Vec| Single Node | | Chroma | Hosted and Single Node | | Postgres (PGVector) | Hosted and Single Node | | Weaviate | Hosted | @@ -88,6 +89,7 @@ self introduction/index getting_started/index concepts/index +providers/index distributions/index distributions/selection building_applications/index diff --git a/docs/source/providers/index.md b/docs/source/providers/index.md new file mode 100644 index 000000000..cc654823e --- /dev/null +++ b/docs/source/providers/index.md @@ -0,0 +1,59 @@ +# Providers Overview + +The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include: +- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, etc.), +- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.), +- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.) + +Providers come in two flavors: +- **Remote**: the provider runs as a separate service external to the Llama Stack codebase. Llama Stack contains a small amount of adapter code. +- **Inline**: the provider is fully specified and implemented within the Llama Stack codebase. It may be a simple wrapper around an existing library, or a full fledged implementation within Llama Stack. + +Importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally. + +## Agents +Run multi-step agentic workflows with LLMs with tool usage, memory (RAG), etc. + +## DatasetIO +Interfaces with datasets and data loaders. + +## Eval +Generates outputs (via Inference or Agents) and perform scoring. + +## Inference +Runs inference with an LLM. + +## Post Training +Fine-tunes a model. + +## Safety +Applies safety policies to the output at a Systems (not only model) level. + +## Scoring +Evaluates the outputs of the system. + +## Telemetry +Collects telemetry data from the system. + +## Tool Runtime +Is associated with the ToolGroup resouces. + +## Vector IO + +Vector IO refers to operations on vector databases, such as adding documents, searching, and deleting documents. +Vector IO plays a crucial role in [Retreival Augmented Generation (RAG)](../..//building_applications/rag), where the vector +io and database are used to store and retrieve documents for retrieval. + +#### Vector IO Providers +The following providers (i.e., databases) are available for Vector IO: + +```{toctree} +:maxdepth: 1 + +vector_io/faiss +vector_io/sqlite-vec +vector_io/chromadb +vector_io/pgvector +vector_io/qdrant +vector_io/weaviate +``` diff --git a/docs/source/providers/vector_io/chromadb.md b/docs/source/providers/vector_io/chromadb.md new file mode 100644 index 000000000..4a7caf2e1 --- /dev/null +++ b/docs/source/providers/vector_io/chromadb.md @@ -0,0 +1,36 @@ +--- +orphan: true +--- +# Chroma + +[Chroma](https://www.trychroma.com/) is an inline and remote vector +database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database. +That means you're not limited to storing vectors in memory or in a separate service. + +## Features +Chroma supports: +- Store embeddings and their metadata +- Vector search +- Full-text search +- Document storage +- Metadata filtering +- Multi-modal retrieval + +## Usage + +To use Chrome in your Llama Stack project, follow these steps: + +1. Install the necessary dependencies. +2. Configure your Llama Stack project to use chroma. +3. Start storing and querying vectors. + +## Installation + +You can install chroma using pip: + +```bash +pip install chromadb +``` + +## Documentation +See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general. diff --git a/docs/source/providers/vector_io/faiss.md b/docs/source/providers/vector_io/faiss.md new file mode 100644 index 000000000..f894190eb --- /dev/null +++ b/docs/source/providers/vector_io/faiss.md @@ -0,0 +1,33 @@ +--- +orphan: true +--- +# Faiss + +[Faiss](https://github.com/facebookresearch/faiss) is an inline vector database provider for Llama Stack. It +allows you to store and query vectors directly in memory. +That means you'll get fast and efficient vector retrieval. + +## Features + +- Lightweight and easy to use +- Fully integrated with Llama Stack +- GPU support + +## Usage + +To use Faiss in your Llama Stack project, follow these steps: + +1. Install the necessary dependencies. +2. Configure your Llama Stack project to use Faiss. +3. Start storing and querying vectors. + +## Installation + +You can install Faiss using pip: + +```bash +pip install faiss-cpu +``` +## Documentation +See [Faiss' documentation](https://faiss.ai/) or the [Faiss Wiki](https://github.com/facebookresearch/faiss/wiki) for +more details about Faiss in general. diff --git a/docs/source/providers/vector_io/pgvector.md b/docs/source/providers/vector_io/pgvector.md new file mode 100644 index 000000000..919eb88d8 --- /dev/null +++ b/docs/source/providers/vector_io/pgvector.md @@ -0,0 +1,31 @@ +--- +orphan: true +--- +# Postgres PGVector + +[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It +allows you to store and query vectors directly in memory. +That means you'll get fast and efficient vector retrieval. + +## Features + +- Easy to use +- Fully integrated with Llama Stack + +## Usage + +To use PGVector in your Llama Stack project, follow these steps: + +1. Install the necessary dependencies. +2. Configure your Llama Stack project to use Faiss. +3. Start storing and querying vectors. + +## Installation + +You can install PGVector using docker: + +```bash +docker pull pgvector/pgvector:pg17 +``` +## Documentation +See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general. diff --git a/docs/source/providers/vector_io/qdrant.md b/docs/source/providers/vector_io/qdrant.md new file mode 100644 index 000000000..c374ade98 --- /dev/null +++ b/docs/source/providers/vector_io/qdrant.md @@ -0,0 +1,31 @@ +--- +orphan: true +--- +# Qdrant + +[Qdrant](https://qdrant.tech/documentation/) is a remote vector database provider for Llama Stack. It +allows you to store and query vectors directly in memory. +That means you'll get fast and efficient vector retrieval. + +## Features + +- Easy to use +- Fully integrated with Llama Stack + +## Usage + +To use Qdrant in your Llama Stack project, follow these steps: + +1. Install the necessary dependencies. +2. Configure your Llama Stack project to use Faiss. +3. Start storing and querying vectors. + +## Installation + +You can install Qdrant using docker: + +```bash +docker pull qdrant/qdrant +``` +## Documentation +See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general. diff --git a/docs/source/providers/vector_io/sqlite-vec.md b/docs/source/providers/vector_io/sqlite-vec.md new file mode 100644 index 000000000..f5ce4c003 --- /dev/null +++ b/docs/source/providers/vector_io/sqlite-vec.md @@ -0,0 +1,33 @@ +--- +orphan: true +--- +# SQLite-Vec + +[SQLite-Vec](https://github.com/asg017/sqlite-vec) is an inline vector database provider for Llama Stack. It +allows you to store and query vectors directly within an SQLite database. +That means you're not limited to storing vectors in memory or in a separate service. + +## Features + +- Lightweight and easy to use +- Fully integrated with Llama Stack + +## Usage + +To use SQLite-Vec in your Llama Stack project, follow these steps: + +1. Install the necessary dependencies. +2. Configure your Llama Stack project to use SQLite-Vec. +3. Start storing and querying vectors. + +## Installation + +You can install SQLite-Vec using pip: + +```bash +pip install sqlite-vec +``` + +## Documentation + +See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) for more details about sqlite-vec in general. diff --git a/docs/source/providers/vector_io/weaviate.md b/docs/source/providers/vector_io/weaviate.md new file mode 100644 index 000000000..47321781c --- /dev/null +++ b/docs/source/providers/vector_io/weaviate.md @@ -0,0 +1,33 @@ +--- +orphan: true +--- +# Weaviate + +[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack. +It allows you to store and query vectors directly within a Weaviate database. +That means you're not limited to storing vectors in memory or in a separate service. + +## Features +Weaviate supports: +- Store embeddings and their metadata +- Vector search +- Full-text search +- Hybrid search +- Document storage +- Metadata filtering +- Multi-modal retrieval + +## Usage + +To use Weaviate in your Llama Stack project, follow these steps: + +1. Install the necessary dependencies. +2. Configure your Llama Stack project to use chroma. +3. Start storing and querying vectors. + +## Installation + +To install Weaviate see the [Weaviate quickstart documentation](https://weaviate.io/developers/weaviate/quickstart). + +## Documentation +See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general. From 6227e1e3b9a1164000b18286791dccdf2a2933d9 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sun, 23 Feb 2025 16:57:11 -0800 Subject: [PATCH 37/45] fix: update virtualenv building so llamastack- prefix is not added, make notebook experience easier (#1225) Make sure venv behaves like conda (no prefix is added to image_name) and `--image-type venv` inside a notebook "just works" without any fiddling --- .pre-commit-config.yaml | 1 + llama_stack/cli/stack/_build.py | 16 ++++++++++++++-- llama_stack/distribution/build_venv.sh | 13 +++++++++---- llama_stack/distribution/library_client.py | 14 +------------- llama_stack/distribution/start_venv.sh | 1 + llama_stack/distribution/utils/exec.py | 13 +++++++++++++ 6 files changed, 39 insertions(+), 19 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 85cb1b91a..70af72a62 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -30,6 +30,7 @@ repos: rev: v0.9.4 hooks: - id: ruff + args: [ --fix ] exclude: ^llama_stack/strong_typing/.*$ - id: ruff-format diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index 76f03aa5c..666c2e6dd 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -37,6 +37,7 @@ from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.distribution.utils.dynamic import instantiate_class_type +from llama_stack.distribution.utils.exec import in_notebook from llama_stack.providers.datatypes import Api TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates" @@ -59,8 +60,16 @@ def run_stack_build_command(args: argparse.Namespace) -> None: if args.list_templates: return _run_template_list_cmd() - current_conda_env = os.environ.get("CONDA_DEFAULT_ENV") - image_name = args.image_name or current_conda_env + if args.image_type == "venv": + current_venv = os.environ.get("VIRTUAL_ENV") + image_name = args.image_name or current_venv + if not image_name and in_notebook(): + image_name = "__system__" + elif args.image_type == "conda": + current_conda_env = os.environ.get("CONDA_DEFAULT_ENV") + image_name = args.image_name or current_conda_env + else: + image_name = args.image_name if args.template: available_templates = available_templates_specs() @@ -256,6 +265,9 @@ def _run_stack_build_command_from_build_config( elif build_config.image_type == ImageType.conda.value: if not image_name: raise ValueError("Please specify an image name when building a conda image") + elif build_config.image_type == ImageType.venv.value: + if not image_name: + raise ValueError("Please specify an image name when building a venv image") if template_name: build_dir = DISTRIBS_BASE_DIR / template_name diff --git a/llama_stack/distribution/build_venv.sh b/llama_stack/distribution/build_venv.sh index b47cfcb83..f973fe955 100755 --- a/llama_stack/distribution/build_venv.sh +++ b/llama_stack/distribution/build_venv.sh @@ -16,6 +16,7 @@ TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-} # Reference: https://github.com/astral-sh/uv/pull/1694 UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500} UV_SYSTEM_PYTHON=${UV_SYSTEM_PYTHON:-} +VIRTUAL_ENV=${VIRTUAL_ENV:-} if [ -n "$LLAMA_STACK_DIR" ]; then echo "Using llama-stack-dir=$LLAMA_STACK_DIR" @@ -25,7 +26,7 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then fi if [ "$#" -lt 3 ]; then - echo "Usage: $0 []" >&2 + echo "Usage: $0 []" >&2 echo "Example: $0 mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2 exit 1 fi @@ -34,8 +35,7 @@ special_pip_deps="$3" set -euo pipefail -build_name="$1" -env_name="llamastack-$build_name" +env_name="$1" pip_dependencies="$2" # Define color codes @@ -75,8 +75,12 @@ run() { local pip_dependencies="$2" local special_pip_deps="$3" - if [ -n "$UV_SYSTEM_PYTHON" ]; then + if [ -n "$UV_SYSTEM_PYTHON" ] || [ "$env_name" == "__system__" ]; then echo "Installing dependencies in system Python environment" + # if env == __system__, ensure we set UV_SYSTEM_PYTHON + export UV_SYSTEM_PYTHON=1 + elif [ "$VIRTUAL_ENV" == "$env_name" ]; then + echo "Virtual environment $env_name is already active" else echo "Using virtual environment $env_name" uv venv "$env_name" @@ -90,6 +94,7 @@ run() { # shellcheck disable=SC2086 # we are building a command line so word splitting is expected uv pip install --extra-index-url https://test.pypi.org/simple/ \ + --index-strategy unsafe-best-match \ llama-models=="$TEST_PYPI_VERSION" llama-stack=="$TEST_PYPI_VERSION" \ $pip_dependencies if [ -n "$special_pip_deps" ]; then diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 5790c498b..59189f8bb 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -41,6 +41,7 @@ from llama_stack.distribution.stack import ( redact_sensitive_fields, replace_env_vars, ) +from llama_stack.distribution.utils.exec import in_notebook from llama_stack.providers.utils.telemetry.tracing import ( end_trace, setup_logger, @@ -52,19 +53,6 @@ logger = logging.getLogger(__name__) T = TypeVar("T") -def in_notebook(): - try: - from IPython import get_ipython - - if "IPKernelApp" not in get_ipython().config: # pragma: no cover - return False - except ImportError: - return False - except AttributeError: - return False - return True - - def convert_pydantic_to_json_value(value: Any) -> Any: if isinstance(value, Enum): return value.value diff --git a/llama_stack/distribution/start_venv.sh b/llama_stack/distribution/start_venv.sh index 1cfa7248f..195274129 100755 --- a/llama_stack/distribution/start_venv.sh +++ b/llama_stack/distribution/start_venv.sh @@ -55,6 +55,7 @@ while [[ $# -gt 0 ]]; do esac done +echo "Using virtual environment: $venv_path" # Activate virtual environment if [ ! -d "$venv_path" ]; then echo -e "${RED}Error: Virtual environment not found at $venv_path${NC}" >&2 diff --git a/llama_stack/distribution/utils/exec.py b/llama_stack/distribution/utils/exec.py index 4a3a95826..e13e59aad 100644 --- a/llama_stack/distribution/utils/exec.py +++ b/llama_stack/distribution/utils/exec.py @@ -22,6 +22,19 @@ def run_with_pty(command): return _run_with_pty_unix(command) +def in_notebook(): + try: + from IPython import get_ipython + + if "IPKernelApp" not in get_ipython().config: # pragma: no cover + return False + except ImportError: + return False + except AttributeError: + return False + return True + + # run a command in a pseudo-terminal, with interrupt handling, # useful when you want to run interactive things def _run_with_pty_unix(command): From 34e3faa4e833e5b3dea9310de3b54e97413b14f8 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Sun, 23 Feb 2025 22:06:09 -0500 Subject: [PATCH 38/45] feat: add --run to llama stack build (#1156) # What does this PR do? --run runs the stack that was just build using the same arguments during the build process (image-name, type, etc) This simplifies the workflow a lot and makes the UX better for most local users trying to get started rather than having to match the flags of the two commands (build and then run) Also, moved `ImageType` to distribution.utils since there were circular import errors with its old location ## Test Plan tested locally using the following command: `llama stack build --run --template ollama --image-type venv` Signed-off-by: Charlie Doern --- llama_stack/cli/stack/_build.py | 47 +++++++++--- llama_stack/cli/stack/build.py | 7 ++ llama_stack/cli/stack/run.py | 71 +------------------ llama_stack/distribution/build.py | 8 +-- llama_stack/distribution/utils/exec.py | 70 ++++++++++++++++++ llama_stack/distribution/utils/image_types.py | 13 ++++ 6 files changed, 129 insertions(+), 87 deletions(-) create mode 100644 llama_stack/distribution/utils/image_types.py diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index 666c2e6dd..97d8900df 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -23,10 +23,10 @@ from termcolor import cprint from llama_stack.cli.table import print_table from llama_stack.distribution.build import ( SERVER_DEPENDENCIES, - ImageType, build_image, get_provider_dependencies, ) +from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.datatypes import ( BuildConfig, DistributionSpec, @@ -37,7 +37,8 @@ from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.resolver import InvalidProviderError from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.distribution.utils.dynamic import instantiate_class_type -from llama_stack.distribution.utils.exec import in_notebook +from llama_stack.distribution.utils.exec import formulate_run_args, in_notebook, run_with_pty +from llama_stack.distribution.utils.image_types import ImageType from llama_stack.providers.datatypes import Api TEMPLATES_PATH = Path(__file__).parent.parent.parent / "templates" @@ -186,19 +187,41 @@ def run_stack_build_command(args: argparse.Namespace) -> None: print(f"uv pip install {special_dep}") return - _run_stack_build_command_from_build_config( - build_config, - image_name=image_name, - config_path=args.config, - template_name=args.template, - ) + try: + run_config = _run_stack_build_command_from_build_config( + build_config, + image_name=image_name, + config_path=args.config, + template_name=args.template, + ) + + except Exception as exc: + cprint( + f"Error building stack: {exc}", + color="red", + ) + return + if run_config is None: + cprint( + "Run config path is empty", + color="red", + ) + return + + if args.run: + run_config = Path(run_config) + config_dict = yaml.safe_load(run_config.read_text()) + config = parse_and_maybe_upgrade_config(config_dict) + run_args = formulate_run_args(args.image_type, args.image_name, config, args.template) + run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))]) + run_with_pty(run_args) def _generate_run_config( build_config: BuildConfig, build_dir: Path, image_name: str, -) -> None: +) -> str: """ Generate a run.yaml template file for user to edit from a build.yaml file """ @@ -248,6 +271,7 @@ def _generate_run_config( f"You can now run your stack with `llama stack run {run_config_file}`", color="green", ) + return run_config_file def _run_stack_build_command_from_build_config( @@ -255,7 +279,7 @@ def _run_stack_build_command_from_build_config( image_name: Optional[str] = None, template_name: Optional[str] = None, config_path: Optional[str] = None, -) -> None: +) -> str: if build_config.image_type == ImageType.container.value: if template_name: image_name = f"distribution-{template_name}" @@ -298,8 +322,9 @@ def _run_stack_build_command_from_build_config( shutil.copy(path, run_config_file) cprint("Build Successful!", color="green") + return template_path else: - _generate_run_config(build_config, build_dir, image_name) + return _generate_run_config(build_config, build_dir, image_name) def _run_template_list_cmd() -> None: diff --git a/llama_stack/cli/stack/build.py b/llama_stack/cli/stack/build.py index 7b17a960a..ceee725e6 100644 --- a/llama_stack/cli/stack/build.py +++ b/llama_stack/cli/stack/build.py @@ -68,6 +68,13 @@ the build. If not specified, currently active Conda environment will be used if help="Print the dependencies for the stack only, without building the stack", ) + self.parser.add_argument( + "--run", + action="store_true", + default=False, + help="Run the stack after building using the same image type, name, and other applicable arguments", + ) + def _run_stack_build_command(self, args: argparse.Namespace) -> None: # always keep implementation completely silo-ed away from CLI so CLI # can be fast to load and reduces dependencies diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index 0c9c74518..627ee829a 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -74,10 +74,6 @@ class StackRun(Subcommand): ) def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: - import importlib.resources - import json - import subprocess - import yaml from termcolor import cprint @@ -87,7 +83,7 @@ class StackRun(Subcommand): BUILDS_BASE_DIR, DISTRIBS_BASE_DIR, ) - from llama_stack.distribution.utils.exec import run_with_pty + from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty if not args.config: self.parser.error("Must specify a config file to run") @@ -125,70 +121,7 @@ class StackRun(Subcommand): config_dict = yaml.safe_load(config_file.read_text()) config = parse_and_maybe_upgrade_config(config_dict) - if args.image_type == ImageType.container.value or config.container_image: - script = importlib.resources.files("llama_stack") / "distribution/start_container.sh" - image_name = f"distribution-{template_name}" if template_name else config.container_image - run_args = [script, image_name] - elif args.image_type == ImageType.conda.value: - current_conda_env = os.environ.get("CONDA_DEFAULT_ENV") - image_name = args.image_name or current_conda_env - if not image_name: - cprint( - "No current conda environment detected, please specify a conda environment name with --image-name", - color="red", - ) - return - - def get_conda_prefix(env_name): - # Conda "base" environment does not end with "base" in the - # prefix, so should be handled separately. - if env_name == "base": - return os.environ.get("CONDA_PREFIX") - # Get conda environments info - conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode()) - envs = conda_env_info["envs"] - for envpath in envs: - if envpath.endswith(env_name): - return envpath - return None - - print(f"Using conda environment: {image_name}") - conda_prefix = get_conda_prefix(image_name) - if not conda_prefix: - cprint( - f"Conda environment {image_name} does not exist.", - color="red", - ) - return - - build_file = Path(conda_prefix) / "llamastack-build.yaml" - if not build_file.exists(): - cprint( - f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name", - color="red", - ) - return - - script = importlib.resources.files("llama_stack") / "distribution/start_conda_env.sh" - run_args = [ - script, - image_name, - ] - else: - # else must be venv since that is the only valid option left. - current_venv = os.environ.get("VIRTUAL_ENV") - venv = args.image_name or current_venv - if not venv: - cprint( - "No current virtual environment detected, please specify a virtual environment name with --image-name", - color="red", - ) - return - script = importlib.resources.files("llama_stack") / "distribution/start_venv.sh" - run_args = [ - script, - venv, - ] + run_args = formulate_run_args(args.image_type, args.image_name, config, template_name) run_args.extend([str(config_file), str(args.port)]) if args.disable_ipv6: diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py index 511817de8..2b43b8128 100644 --- a/llama_stack/distribution/build.py +++ b/llama_stack/distribution/build.py @@ -7,7 +7,6 @@ import importlib.resources import logging import sys -from enum import Enum from pathlib import Path from typing import Dict, List @@ -18,6 +17,7 @@ from llama_stack.distribution.datatypes import BuildConfig, Provider from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR from llama_stack.distribution.utils.exec import run_command, run_with_pty +from llama_stack.distribution.utils.image_types import ImageType from llama_stack.providers.datatypes import Api log = logging.getLogger(__name__) @@ -33,12 +33,6 @@ SERVER_DEPENDENCIES = [ ] -class ImageType(Enum): - container = "container" - conda = "conda" - venv = "venv" - - class ApiInput(BaseModel): api: Api provider: str diff --git a/llama_stack/distribution/utils/exec.py b/llama_stack/distribution/utils/exec.py index e13e59aad..00afdadbe 100644 --- a/llama_stack/distribution/utils/exec.py +++ b/llama_stack/distribution/utils/exec.py @@ -12,8 +12,78 @@ import signal import subprocess import sys +from termcolor import cprint + log = logging.getLogger(__name__) +import importlib +import json +from pathlib import Path + +from llama_stack.distribution.utils.image_types import ImageType + + +def formulate_run_args(image_type, image_name, config, template_name) -> list: + if image_type == ImageType.container.value or config.container_image: + script = importlib.resources.files("llama_stack") / "distribution/start_container.sh" + image_name = f"distribution-{template_name}" if template_name else config.container_image + run_args = [script, image_name] + elif image_type == ImageType.conda.value: + current_conda_env = os.environ.get("CONDA_DEFAULT_ENV") + image_name = image_name or current_conda_env + if not image_name: + cprint( + "No current conda environment detected, please specify a conda environment name with --image-name", + color="red", + ) + return + + def get_conda_prefix(env_name): + # Conda "base" environment does not end with "base" in the + # prefix, so should be handled separately. + if env_name == "base": + return os.environ.get("CONDA_PREFIX") + # Get conda environments info + conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode()) + envs = conda_env_info["envs"] + for envpath in envs: + if envpath.endswith(env_name): + return envpath + return None + + print(f"Using conda environment: {image_name}") + conda_prefix = get_conda_prefix(image_name) + if not conda_prefix: + cprint( + f"Conda environment {image_name} does not exist.", + color="red", + ) + return + + build_file = Path(conda_prefix) / "llamastack-build.yaml" + if not build_file.exists(): + cprint( + f"Build file {build_file} does not exist.\n\nPlease run `llama stack build` or specify the correct conda environment name with --image-name", + color="red", + ) + return + + script = importlib.resources.files("llama_stack") / "distribution/start_conda_env.sh" + run_args = [ + script, + image_name, + ] + else: + # else must be venv since that is the only valid option left. + current_venv = os.environ.get("VIRTUAL_ENV") + venv = image_name or current_venv + script = importlib.resources.files("llama_stack") / "distribution/start_venv.sh" + run_args = [ + script, + venv, + ] + return run_args + def run_with_pty(command): if sys.platform.startswith("win"): diff --git a/llama_stack/distribution/utils/image_types.py b/llama_stack/distribution/utils/image_types.py new file mode 100644 index 000000000..1a43b092f --- /dev/null +++ b/llama_stack/distribution/utils/image_types.py @@ -0,0 +1,13 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import Enum + + +class ImageType(Enum): + container = "container" + conda = "conda" + venv = "venv" From 17162b997830789485613fe9882dc89c6f814c93 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Sun, 23 Feb 2025 23:16:30 -0500 Subject: [PATCH 39/45] docs: Add vLLM to the list of inference providers in concepts and providers pages (#1227) This increases visibility of the vLLM provider. --- docs/source/concepts/index.md | 2 +- docs/source/providers/index.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/concepts/index.md b/docs/source/concepts/index.md index 27eb74f00..c839266b6 100644 --- a/docs/source/concepts/index.md +++ b/docs/source/concepts/index.md @@ -25,7 +25,7 @@ We are working on adding a few more APIs to complete the application lifecycle. ## API Providers The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include: -- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, etc.), +- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.), - Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.), - Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.) diff --git a/docs/source/providers/index.md b/docs/source/providers/index.md index cc654823e..e039e90b0 100644 --- a/docs/source/providers/index.md +++ b/docs/source/providers/index.md @@ -1,7 +1,7 @@ # Providers Overview The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include: -- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, etc.), +- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.), - Vector databases (e.g., ChromaDB, Weaviate, Qdrant, FAISS, PGVector, etc.), - Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.) From 0973d386e658a570edd88d3c6bf7869f6794b7d8 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sun, 23 Feb 2025 21:47:18 -0800 Subject: [PATCH 40/45] fix: update build_container.sh to ensure llama-models is installed first --- llama_stack/distribution/build_container.sh | 34 ++++++++++----------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/llama_stack/distribution/build_container.sh b/llama_stack/distribution/build_container.sh index 7c6d758c0..3a27c5046 100755 --- a/llama_stack/distribution/build_container.sh +++ b/llama_stack/distribution/build_container.sh @@ -107,6 +107,22 @@ fi stack_mount="/app/llama-stack-source" models_mount="/app/llama-models-source" +if [ -n "$LLAMA_MODELS_DIR" ]; then + if [ ! -d "$LLAMA_MODELS_DIR" ]; then + echo "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}" >&2 + exit 1 + fi + + if [ "$USE_COPY_NOT_MOUNT" = "true" ]; then + add_to_container << EOF +COPY $LLAMA_MODELS_DIR $models_mount +EOF + fi + add_to_container << EOF +RUN uv pip install --no-cache -e $models_mount +EOF +fi + if [ -n "$LLAMA_STACK_DIR" ]; then if [ ! -d "$LLAMA_STACK_DIR" ]; then echo "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}" >&2 @@ -134,6 +150,7 @@ RUN uv pip install fastapi libcst EOF add_to_container << EOF RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \ + --index-strategy unsafe-best-match \ llama-models==$TEST_PYPI_VERSION llama-stack-client==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION EOF @@ -149,23 +166,6 @@ EOF fi fi -if [ -n "$LLAMA_MODELS_DIR" ]; then - if [ ! -d "$LLAMA_MODELS_DIR" ]; then - echo "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}" >&2 - exit 1 - fi - - if [ "$USE_COPY_NOT_MOUNT" = "true" ]; then - add_to_container << EOF -COPY $LLAMA_MODELS_DIR $models_mount -EOF - fi - add_to_container << EOF -RUN uv pip uninstall llama-models -RUN uv pip install --no-cache $models_mount -EOF -fi - # if template_or_config ends with .yaml, it is not a template and we should not use the --template flag if [[ "$template_or_config" != *.yaml ]]; then add_to_container << EOF From 1842eeb96fbc3866ed908d4af4f228b2cf1b7831 Mon Sep 17 00:00:00 2001 From: Reid <61492567+reidliu41@users.noreply.github.com> Date: Mon, 24 Feb 2025 20:59:58 +0800 Subject: [PATCH 41/45] docs: small fixes (#1224) --- docs/source/distributions/selection.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/distributions/selection.md b/docs/source/distributions/selection.md index da1b0df9c..269b14bce 100644 --- a/docs/source/distributions/selection.md +++ b/docs/source/distributions/selection.md @@ -17,7 +17,7 @@ Which templates / distributions to choose depends on the hardware you have for r - {dockerhub}`distribution-nvidia` ([Guide](self_hosted_distro/nvidia)) - **Are you running on a "regular" desktop or laptop ?** We suggest using the ollama template for quick prototyping and get started without having to worry about needing GPUs. - - {dockerhub}`distribution-ollama` ([link](self_hosted_distro/ollama)) + - {dockerhub}`distribution-ollama` ([Guide](self_hosted_distro/ollama)) - **Do you have an API key for a remote inference provider like Fireworks, Together, etc.?** If so, we suggest: - {dockerhub}`distribution-together` ([Guide](self_hosted_distro/together)) @@ -28,7 +28,7 @@ Which templates / distributions to choose depends on the hardware you have for r - [Android](ondevice_distro/android_sdk) -- **If none of the above fit your needs, you can also build your own [custom distribution](building_distro).** +- **If none of the above fit your needs, you can also build your own [custom distribution](building_distro.md).** ### Distribution Details From 641549c63144a93fba2403b07d20e95ec6a9b83f Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 24 Feb 2025 07:51:02 -0800 Subject: [PATCH 42/45] Add llama stack client overrides also; necessary for correct docker building --- llama_stack/distribution/build_container.sh | 45 ++++++++++----------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/llama_stack/distribution/build_container.sh b/llama_stack/distribution/build_container.sh index 3a27c5046..022c0a41c 100755 --- a/llama_stack/distribution/build_container.sh +++ b/llama_stack/distribution/build_container.sh @@ -8,6 +8,8 @@ LLAMA_MODELS_DIR=${LLAMA_MODELS_DIR:-} LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-} +LLAMA_STACK_CLIENT_DIR=${LLAMA_STACK_CLIENT_DIR:-} + TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-} PYPI_VERSION=${PYPI_VERSION:-} BUILD_PLATFORM=${BUILD_PLATFORM:-} @@ -106,42 +108,39 @@ fi stack_mount="/app/llama-stack-source" models_mount="/app/llama-models-source" +client_mount="/app/llama-stack-client-source" -if [ -n "$LLAMA_MODELS_DIR" ]; then - if [ ! -d "$LLAMA_MODELS_DIR" ]; then - echo "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}" >&2 +install_local_package() { + local dir="$1" + local mount_point="$2" + local name="$3" + + if [ ! -d "$dir" ]; then + echo "${RED}Warning: $name is set but directory does not exist: $dir${NC}" >&2 exit 1 fi if [ "$USE_COPY_NOT_MOUNT" = "true" ]; then add_to_container << EOF -COPY $LLAMA_MODELS_DIR $models_mount +COPY $dir $mount_point EOF fi add_to_container << EOF -RUN uv pip install --no-cache -e $models_mount +RUN uv pip install --no-cache -e $mount_point EOF +} + + +if [ -n "$LLAMA_MODELS_DIR" ]; then + install_local_package "$LLAMA_MODELS_DIR" "$models_mount" "LLAMA_MODELS_DIR" +fi + +if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then + install_local_package "$LLAMA_STACK_CLIENT_DIR" "$client_mount" "LLAMA_STACK_CLIENT_DIR" fi if [ -n "$LLAMA_STACK_DIR" ]; then - if [ ! -d "$LLAMA_STACK_DIR" ]; then - echo "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}" >&2 - exit 1 - fi - - # Install in editable format. We will mount the source code into the container - # so that changes will be reflected in the container without having to do a - # rebuild. This is just for development convenience. - - if [ "$USE_COPY_NOT_MOUNT" = "true" ]; then - add_to_container << EOF -COPY $LLAMA_STACK_DIR $stack_mount -EOF - fi - - add_to_container << EOF -RUN uv pip install --no-cache -e $stack_mount -EOF + install_local_package "$LLAMA_STACK_DIR" "$stack_mount" "LLAMA_STACK_DIR" else if [ -n "$TEST_PYPI_VERSION" ]; then # these packages are damaged in test-pypi, so install them first From e8e8fe7c93fc3289414c7b9f50b313f6ee5a29d8 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 24 Feb 2025 10:00:57 -0800 Subject: [PATCH 43/45] fix: add LLAMA_STACK_CLIENT_DIR mount when installing in docker from source --- llama_stack/distribution/build_container.sh | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llama_stack/distribution/build_container.sh b/llama_stack/distribution/build_container.sh index 022c0a41c..5f595af2c 100755 --- a/llama_stack/distribution/build_container.sh +++ b/llama_stack/distribution/build_container.sh @@ -197,6 +197,9 @@ if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then if [ -n "$LLAMA_MODELS_DIR" ]; then mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount" fi + if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then + mounts="$mounts -v $(readlink -f $LLAMA_STACK_CLIENT_DIR):$client_mount" + fi fi if command -v selinuxenabled &>/dev/null && selinuxenabled; then From d6356f822ab0adfea22d3767e4a53531819707a0 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 24 Feb 2025 10:05:02 -0800 Subject: [PATCH 44/45] fix: remove UV_SYSTEM_PYTHON from getting started notebook since llama stack build detects notebook environment --- docs/getting_started.ipynb | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/getting_started.ipynb b/docs/getting_started.ipynb index 51ae945f4..7f9afd647 100644 --- a/docs/getting_started.ipynb +++ b/docs/getting_started.ipynb @@ -86,8 +86,6 @@ "# NBVAL_SKIP\n", "\n", "!apt-get install -y bubblewrap\n", - "import os\n", - "os.environ[\"UV_SYSTEM_PYTHON\"] = \"1\"\n", "!pip install uv\n", "!uv pip install llama-stack" ] @@ -3632,7 +3630,7 @@ "provenance": [] }, "kernelspec": { - "display_name": "master", + "display_name": "toolchain", "language": "python", "name": "python3" }, From c4987bc349bf9319bbe17ac7a201121cf4b34312 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Mon, 24 Feb 2025 19:18:52 +0100 Subject: [PATCH 45/45] fix: avoid failure when no special pip deps and better exit (#1228) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? When building providers in a virtual environment or containers, special pip dependencies may not always be provided (e.g., for Ollama). The check should only fail if the required number of arguments is missing. Currently, two arguments are mandatory: 1. Environment name 2. Pip dependencies Additionally, return statements were replaced with sys.exit(1) in error conditions to ensure immediate termination on critical failures. Error handling in the stack build process was also improved to guarantee the program exits with status 1 when facing configuration issues or build failures. Signed-off-by: Sébastien Han [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan This command shouldn't fail: ``` llama stack build --template ollama --image-type venv ``` [//]: # (## Documentation) Signed-off-by: Sébastien Han --- llama_stack/cli/stack/_build.py | 17 +++++++++-------- llama_stack/distribution/build_container.sh | 2 +- llama_stack/distribution/build_venv.sh | 6 +++--- 3 files changed, 13 insertions(+), 12 deletions(-) diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index 97d8900df..96382d428 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -9,6 +9,7 @@ import importlib.resources import json import os import shutil +import sys import textwrap from functools import lru_cache from pathlib import Path @@ -79,7 +80,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: f"Could not find template {args.template}. Please run `llama stack build --list-templates` to check out the available templates", color="red", ) - return + sys.exit(1) build_config = available_templates[args.template] if args.image_type: build_config.image_type = args.image_type @@ -88,7 +89,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None: f"Please specify a image-type (container | conda | venv) for {args.template}", color="red", ) - return + sys.exit(1) elif not args.config and not args.template: name = prompt( "> Enter a name for your Llama Stack (e.g. my-local-stack): ", @@ -169,14 +170,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None: f"Could not parse config file {args.config}: {e}", color="red", ) - return + sys.exit(1) if build_config.image_type == ImageType.container.value and not args.image_name: cprint( "Please specify --image-name when building a container from a config file", color="red", ) - return + sys.exit(1) if args.print_deps_only: print(f"# Dependencies for {args.template or args.config or image_name}") @@ -195,18 +196,18 @@ def run_stack_build_command(args: argparse.Namespace) -> None: template_name=args.template, ) - except Exception as exc: + except (Exception, RuntimeError) as exc: cprint( f"Error building stack: {exc}", color="red", ) - return + sys.exit(1) if run_config is None: cprint( "Run config path is empty", color="red", ) - return + sys.exit(1) if args.run: run_config = Path(run_config) @@ -312,7 +313,7 @@ def _run_stack_build_command_from_build_config( template_or_config=template_name or config_path, ) if return_code != 0: - return + raise RuntimeError(f"Failed to build image {image_name}") if template_name: # copy run.yaml from template to build_dir instead of generating it again diff --git a/llama_stack/distribution/build_container.sh b/llama_stack/distribution/build_container.sh index 5f595af2c..08941a538 100755 --- a/llama_stack/distribution/build_container.sh +++ b/llama_stack/distribution/build_container.sh @@ -34,7 +34,7 @@ container_base="$3" build_file_path="$4" host_build_dir="$5" pip_dependencies="$6" -special_pip_deps="$7" +special_pip_deps="${7:-}" # Define color codes diff --git a/llama_stack/distribution/build_venv.sh b/llama_stack/distribution/build_venv.sh index f973fe955..52c5c7051 100755 --- a/llama_stack/distribution/build_venv.sh +++ b/llama_stack/distribution/build_venv.sh @@ -25,7 +25,7 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then echo "Using llama-models-dir=$LLAMA_MODELS_DIR" fi -if [ "$#" -lt 3 ]; then +if [ "$#" -lt 2 ]; then echo "Usage: $0 []" >&2 echo "Example: $0 mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2 exit 1 @@ -74,8 +74,8 @@ run() { local env_name="$1" local pip_dependencies="$2" local special_pip_deps="$3" - - if [ -n "$UV_SYSTEM_PYTHON" ] || [ "$env_name" == "__system__" ]; then + + if [ -n "$UV_SYSTEM_PYTHON" ] || [ "$env_name" == "__system__" ]; then echo "Installing dependencies in system Python environment" # if env == __system__, ensure we set UV_SYSTEM_PYTHON export UV_SYSTEM_PYTHON=1