From b34c1dd8ada4b3c42567a335040ef2ee2b66c656 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 11 Feb 2025 04:38:11 +0100 Subject: [PATCH 01/27] test: replace blocked image URLs with GitHub-hosted (#1025) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? The previous image URLs were sometimes blocked by Cloudflare, causing test failures for some users. This update replaces them with a GitHub-hosted image (`dog.png`) from the `llama-stack` repository, ensuring more reliable access during testing. Signed-off-by: Sébastien Han [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan ``` $ ollama run llama3.2-vision:latest --keep-alive 2m & $ uv run pytest -v -s -k "ollama" --inference-model=llama3.2-vision:latest llama_stack/providers/tests/inference/test_vision_inference.py /Users/leseb/Documents/AI/llama-stack/.venv/lib/python3.13/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset. The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session" warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET)) ============================================ test session starts ============================================= platform darwin -- Python 3.13.1, pytest-8.3.4, pluggy-1.5.0 -- /Users/leseb/Documents/AI/llama-stack/.venv/bin/python3 cachedir: .pytest_cache metadata: {'Python': '3.13.1', 'Platform': 'macOS-15.3-arm64-arm-64bit-Mach-O', 'Packages': {'pytest': '8.3.4', 'pluggy': '1.5.0'}, 'Plugins': {'html': '4.1.1', 'metadata': '3.1.1', 'asyncio': '0.25.3', 'anyio': '4.8.0', 'nbval': '0.11.0'}} rootdir: /Users/leseb/Documents/AI/llama-stack configfile: pyproject.toml plugins: html-4.1.1, metadata-3.1.1, asyncio-0.25.3, anyio-4.8.0, nbval-0.11.0 asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None collected 39 items / 36 deselected / 3 selected llama_stack/providers/tests/inference/test_vision_inference.py::TestVisionModelInference::test_vision_chat_completion_non_streaming[-ollama-image0-expected_strings0] PASSED llama_stack/providers/tests/inference/test_vision_inference.py::TestVisionModelInference::test_vision_chat_completion_non_streaming[-ollama-image1-expected_strings1] PASSED llama_stack/providers/tests/inference/test_vision_inference.py::TestVisionModelInference::test_vision_chat_completion_streaming[-ollama] PASSED ========================== 3 passed, 36 deselected, 2 warnings in 62.23s (0:01:02) ========================== ``` [//]: # (## Documentation) [//]: # (- [ ] Added a Changelog entry if the change is significant) Signed-off-by: Sébastien Han --- .../providers/tests/inference/test_vision_inference.py | 4 ++-- tests/client-sdk/inference/test_vision_inference.py | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/llama_stack/providers/tests/inference/test_vision_inference.py b/llama_stack/providers/tests/inference/test_vision_inference.py index a2434ac41..2f96e66d4 100644 --- a/llama_stack/providers/tests/inference/test_vision_inference.py +++ 
b/llama_stack/providers/tests/inference/test_vision_inference.py @@ -39,7 +39,7 @@ class TestVisionModelInference: ImageContentItem( image=dict( url=URL( - uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg" + uri="https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/client-sdk/inference/dog.png" ) ) ), @@ -80,7 +80,7 @@ class TestVisionModelInference: ImageContentItem( image=dict( url=URL( - uri="https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg" + uri="https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/client-sdk/inference/dog.png" ) ) ), diff --git a/tests/client-sdk/inference/test_vision_inference.py b/tests/client-sdk/inference/test_vision_inference.py index df4b9d933..b23089747 100644 --- a/tests/client-sdk/inference/test_vision_inference.py +++ b/tests/client-sdk/inference/test_vision_inference.py @@ -43,8 +43,7 @@ def test_image_chat_completion_non_streaming(llama_stack_client, vision_model_id "type": "image", "image": { "url": { - # TODO: Replace with Github based URI to resources/sample1.jpg - "uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg" + "uri": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/client-sdk/inference/dog.png" }, }, }, @@ -72,8 +71,7 @@ def test_image_chat_completion_streaming(llama_stack_client, vision_model_id): "type": "image", "image": { "url": { - # TODO: Replace with Github based URI to resources/sample1.jpg - "uri": "https://www.healthypawspetinsurance.com/Images/V3/DogAndPuppyInsurance/Dog_CTA_Desktop_HeroImage.jpg" + "uri": "https://raw.githubusercontent.com/meta-llama/llama-stack/main/tests/client-sdk/inference/dog.png" }, }, }, From d954f2752e386634f861e0447e9998e0ab15d15c Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Tue, 11 Feb 2025 00:20:50 -0500 Subject: [PATCH 02/27] fix: Added missing `tool_config` arg in SambaNova `chat_completion()` (#1042) # What does this PR do? `tool_config` is missing from the signature but is used in `ChatCompletionRequest()`. ## Test Plan This is a small fix. I don't have SambaNova to test the change but I doubt that this is currently working. Signed-off-by: Yuan Tang --- llama_stack/providers/remote/inference/sambanova/sambanova.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index dd697cd62..87aab1e88 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -116,6 +116,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): tool_choice: Optional[ToolChoice] = ToolChoice.auto, tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json, stream: Optional[bool] = False, + tool_config: Optional[ToolConfig] = None, logprobs: Optional[LogProbConfig] = None, ) -> AsyncGenerator: model = await self.model_store.get_model(model_id) From d947ddd2554476eb3ee2d8e6e6505db7f4606cc5 Mon Sep 17 00:00:00 2001 From: Kelly Brown <86735520+kelbrown20@users.noreply.github.com> Date: Tue, 11 Feb 2025 09:53:26 -0500 Subject: [PATCH 03/27] docs: Updating wording and nits in the README.md (#992) # What does this PR do? 
Fixing some wording nits and added small formatting suggestions in the README.md ## Before submitting - [x] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). - [x] Ran pre-commit to handle lint / formatting issues. - [x] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section? - [x] Updated relevant documentation. - [ ] Wrote necessary unit or integration tests. --- README.md | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index a5e5b217d..baec8c1bd 100644 --- a/README.md +++ b/README.md @@ -7,13 +7,13 @@ [**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) -Llama Stack defines and standardizes the core building blocks that simplify AI application development. It codified best practices across the Llama ecosystem. More specifically, it provides +Llama Stack standardizes the core building blocks that simplify AI application development. It codifies best practices across the Llama ecosystem. More specifically, it provides - **Unified API layer** for Inference, RAG, Agents, Tools, Safety, Evals, and Telemetry. -- **Plugin architecture** to support the rich ecosystem of implementations of the different APIs in different environments like local development, on-premises, cloud, and mobile. -- **Prepackaged verified distributions** which offer a one-stop solution for developers to get started quickly and reliably in any environment -- **Multiple developer interfaces** like CLI and SDKs for Python, Typescript, iOS, and Android -- **Standalone applications** as examples for how to build production-grade AI applications with Llama Stack +- **Plugin architecture** to support the rich ecosystem of different API implementations in various environments, including local development, on-premises, cloud, and mobile. +- **Prepackaged verified distributions** which offer a one-stop solution for developers to get started quickly and reliably in any environment. +- **Multiple developer interfaces** like CLI and SDKs for Python, Typescript, iOS, and Android. +- **Standalone applications** as examples for how to build production-grade AI applications with Llama Stack.
### Llama Stack Benefits -- **Flexible Options**: Developers can choose their preferred infrastructure without changing APIs and enjoy flexible deployment choice. -- **Consistent Experience**: With its unified APIs Llama Stack makes it easier to build, test, and deploy AI applications with consistent application behavior. +- **Flexible Options**: Developers can choose their preferred infrastructure without changing APIs and enjoy flexible deployment choices. +- **Consistent Experience**: With its unified APIs, Llama Stack makes it easier to build, test, and deploy AI applications with consistent application behavior. - **Robust Ecosystem**: Llama Stack is already integrated with distribution partners (cloud providers, hardware vendors, and AI-focused companies) that offer tailored infrastructure, software, and services for deploying Llama models. By reducing friction and complexity, Llama Stack empowers developers to focus on what they do best: building transformative generative AI applications. ### API Providers -Here is a list of the various API providers and available distributions to developers started easily, +Here is a list of the various API providers and available distributions that can help developers get started easily with Llama Stack. | **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** | |:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:| @@ -71,15 +71,15 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider You have two ways to install this repository: -1. **Install as a package**: +* **Install as a package**: You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command: ```bash pip install llama-stack ``` -2. **Install from source**: +* **Install from source**: If you prefer to install from the source code, make sure you have [conda installed](https://docs.conda.io/projects/conda/en/stable). - Then, follow these steps: + Then, run the following commands: ```bash mkdir -p ~/local cd ~/local @@ -96,10 +96,11 @@ You have two ways to install this repository: Please checkout our [Documentation](https://llama-stack.readthedocs.io/en/latest/index.html) page for more details. -* [CLI reference](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/index.html) - * Guide using `llama` CLI to work with Llama models (download, study prompts), and building/starting a Llama Stack distribution. -* [Getting Started](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) - * Quick guide to start a Llama Stack server. +* CLI references + * [llama (server-side) CLI Reference](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/index.html): Guide for using the `llama` CLI to work with Llama models (download, study prompts), and building/starting a Llama Stack distribution. + * [llama (client-side) CLI Reference](https://llama-stack.readthedocs.io/en/latest/references/llama_stack_client_cli_reference.html): Guide for using the `llama-stack-client` CLI, which allows you to query information about the distribution. +* Getting Started + * [Quick guide to start a Llama Stack server](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html). 
* [Jupyter notebook](./docs/getting_started.ipynb) to walk-through how to use simple text and vision inference llama_stack_client APIs * The complete Llama Stack lesson [Colab notebook](https://colab.research.google.com/drive/1dtVmxotBsI4cGZQNsJRYPrLiDeT0Wnwt) of the new [Llama 3.2 course on Deeplearning.ai](https://learn.deeplearning.ai/courses/introducing-multimodal-llama-3-2/lesson/8/llama-stack). * A [Zero-to-Hero Guide](https://github.com/meta-llama/llama-stack/tree/main/docs/zero_to_hero_guide) that guide you through all the key components of llama stack with code samples. @@ -115,6 +116,6 @@ Please checkout our [Documentation](https://llama-stack.readthedocs.io/en/latest | Typescript | [llama-stack-client-typescript](https://github.com/meta-llama/llama-stack-client-typescript) | [![NPM version](https://img.shields.io/npm/v/llama-stack-client.svg)](https://npmjs.org/package/llama-stack-client) | Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) | [![Maven version](https://img.shields.io/maven-central/v/com.llama.llamastack/llama-stack-client-kotlin)](https://central.sonatype.com/artifact/com.llama.llamastack/llama-stack-client-kotlin) -Check out our client SDKs for connecting to Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [typescript](https://github.com/meta-llama/llama-stack-client-typescript), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications. +Check out our client SDKs for connecting to a Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [typescript](https://github.com/meta-llama/llama-stack-client-typescript), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications. You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repo. From 71cae67d7b0ab1248f7c29ac483a65cca11440f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 11 Feb 2025 19:24:53 +0100 Subject: [PATCH 04/27] docs: remove changelog mention from PR template (#1049) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? The CHANGELOG.md was removed in https://github.com/meta-llama/llama-stack/commit/e6c9f2a4856192d6cb57a038d98d21a253c4319a so this mention is not relevant anymore. Signed-off-by: Sébastien Han Signed-off-by: Sébastien Han --- .github/PULL_REQUEST_TEMPLATE.md | 1 - 1 file changed, 1 deletion(-) diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 044518abf..af2058b9a 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -8,4 +8,3 @@ [Describe the tests you ran to verify your changes with result summaries. 
*Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation) -[//]: # (- [ ] Added a Changelog entry if the change is significant) From 6ad272927dbf1292eb36cf6f08a5185dcd38515f Mon Sep 17 00:00:00 2001 From: Ihar Hrachyshka Date: Tue, 11 Feb 2025 17:07:26 -0500 Subject: [PATCH 05/27] docs: reflect actual number of spaces for indent (#1052) For what I see, it's all 4 spaces (as it should be for pep8[1]). [1] https://peps.python.org/pep-0008/#indentation # What does this PR do? Reflect indent reality. --- CONTRIBUTING.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 92939b47a..8028c194e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -98,7 +98,8 @@ $ uv sync ``` ## Coding Style -* 2 spaces for indentation rather than tabs + +* 4 spaces for indentation rather than tabs * 80 character line length * ... From 96c88397da50dc54f4354947aaec8390153763d2 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Tue, 11 Feb 2025 14:48:42 -0800 Subject: [PATCH 06/27] fix: agent config validation (#1053) Summary: Fixes AgentConfig init bug introduced with ToolConfig. Namely, the below doesn't work ``` agent_config = AgentConfig( **common_params, tool_config=ToolConfig( tool_choice="required", ), ) ``` bvecause tool_choice was defaulted to 'auto' leading to validation check failing. Test Plan: added unittests LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/ --safety-shield meta-llama/Llama-Guard-3-8B --- llama_stack/apis/agents/agents.py | 14 +++-- .../inline/agents/meta_reference/agents.py | 6 -- tests/client-sdk/agents/test_agents.py | 59 ++++++++++++++++++- 3 files changed, 66 insertions(+), 13 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 95107d99f..785248633 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -154,7 +154,7 @@ class AgentConfigCommon(BaseModel): output_shields: Optional[List[str]] = Field(default_factory=list) toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list) client_tools: Optional[List[ToolDef]] = Field(default_factory=list) - tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto, deprecated="use tool_config instead") + tool_choice: Optional[ToolChoice] = Field(default=None, deprecated="use tool_config instead") tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead") tool_config: Optional[ToolConfig] = Field(default=None) @@ -166,11 +166,13 @@ class AgentConfigCommon(BaseModel): raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.") if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format: raise ValueError("tool_prompt_format is deprecated. 
Use tool_prompt_format in tool_config instead.") - if self.tool_config is None: - self.tool_config = ToolConfig( - tool_choice=self.tool_choice, - tool_prompt_format=self.tool_prompt_format, - ) + else: + params = {} + if self.tool_choice: + params["tool_choice"] = self.tool_choice + if self.tool_prompt_format: + params["tool_prompt_format"] = self.tool_prompt_format + self.tool_config = ToolConfig(**params) @json_schema_type diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index 8f9fa2d82..fe4ccd1a3 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -81,12 +81,6 @@ class MetaReferenceAgentsImpl(Agents): ) -> AgentCreateResponse: agent_id = str(uuid.uuid4()) - if agent_config.tool_config is None: - agent_config.tool_config = ToolConfig( - tool_choice=agent_config.tool_choice, - tool_prompt_format=agent_config.tool_prompt_format, - ) - await self.persistence_store.set( key=f"agent:{agent_id}", value=agent_config.model_dump_json(), diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 85b7af831..d14a7003f 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -13,11 +13,12 @@ from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.client_tool import ClientTool from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.types import ToolResponseMessage -from llama_stack_client.types.agent_create_params import AgentConfig +from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument from llama_stack_client.types.memory_insert_params import Document from llama_stack_client.types.shared.completion_message import CompletionMessage from llama_stack_client.types.tool_def_param import Parameter +from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig, ToolChoice class TestClientTool(ClientTool): @@ -141,6 +142,62 @@ def test_agent_simple(llama_stack_client, agent_config): assert "I can't" in logs_str +def test_tool_config(llama_stack_client, agent_config): + common_params = dict( + model="meta-llama/Llama-3.2-3B-Instruct", + instructions="You are a helpful assistant", + sampling_params={ + "strategy": { + "type": "top_p", + "temperature": 1.0, + "top_p": 0.9, + }, + }, + toolgroups=[], + enable_session_persistence=False, + ) + agent_config = AgentConfig( + **common_params, + ) + Server__AgentConfig(**agent_config) + + agent_config = AgentConfig( + **common_params, + tool_choice="auto", + ) + server_config = Server__AgentConfig(**agent_config) + assert server_config.tool_config.tool_choice == ToolChoice.auto + + agent_config = AgentConfig( + **common_params, + tool_choice="auto", + tool_config=ToolConfig( + tool_choice="auto", + ), + ) + server_config = Server__AgentConfig(**agent_config) + assert server_config.tool_config.tool_choice == ToolChoice.auto + + agent_config = AgentConfig( + **common_params, + tool_config=ToolConfig( + tool_choice="required", + ), + ) + server_config = Server__AgentConfig(**agent_config) + assert server_config.tool_config.tool_choice == ToolChoice.required + + agent_config = AgentConfig( + **common_params, + tool_choice="required", + tool_config=ToolConfig( + tool_choice="auto", + ), + ) + with 
pytest.raises(ValueError, match="tool_choice is deprecated"): + Server__AgentConfig(**agent_config) + + def test_builtin_tool_web_search(llama_stack_client, agent_config): agent_config = { **agent_config, From ab7f802698b3bf712b5a7ecb2cf043d4a6384668 Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 11 Feb 2025 14:58:12 -0800 Subject: [PATCH 07/27] feat: add MetricResponseMixin to chat completion response types (#1050) # What does this PR do? Defines a MetricResponseMixin which can be inherited by any response class. Adds it to chat completion response types. This is a short term solution to allow inference API to return metrics The ideal way to do this is to have a way for all response types to include metrics and all metric events logged to the telemetry API to be included with the response To do this, we will need to augment all response types with a metrics field. We have hit a blocker from stainless SDK that prevents us from doing this. The blocker is that if we were to augment the response types that have a data field in them like so class ListModelsResponse(BaseModel): metrics: Optional[List[MetricEvent]] = None data: List[Models] ... The client SDK will need to access the data by using a .data field, which is not ergonomic. Stainless SDK does support unwrapping the response type, but it requires that the response type to only have a single field. We will need a way in the client SDK to signal that the metrics are needed and if they are needed, the client SDK has to return the full response type without unwrapping it. ## Test Plan sh run_openapi_generator.sh ./ sh stainless_sync.sh dineshyv/dev add-metrics-to-resp-v4 LLAMA_STACK_CONFIG="/Users/dineshyv/.llama/distributions/fireworks/fireworks-run.yaml" pytest -v tests/client-sdk/agents/test_agents.py --- docs/_static/llama-stack-spec.html | 154 +++++++++++++----------- docs/_static/llama-stack-spec.yaml | 90 +++++++------- llama_stack/apis/inference/inference.py | 7 +- llama_stack/apis/telemetry/telemetry.py | 26 +++- 4 files changed, 161 insertions(+), 116 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 151ac1451..75e0c4dfa 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -3106,6 +3106,12 @@ "ChatCompletionResponse": { "type": "object", "properties": { + "metrics": { + "type": "array", + "items": { + "$ref": "#/components/schemas/MetricEvent" + } + }, "completion_message": { "$ref": "#/components/schemas/CompletionMessage", "description": "The complete response message" @@ -3124,6 +3130,77 @@ ], "description": "Response from a chat completion request." 
}, + "MetricEvent": { + "type": "object", + "properties": { + "trace_id": { + "type": "string" + }, + "span_id": { + "type": "string" + }, + "timestamp": { + "type": "string", + "format": "date-time" + }, + "attributes": { + "type": "object", + "additionalProperties": { + "oneOf": [ + { + "type": "null" + }, + { + "type": "boolean" + }, + { + "type": "number" + }, + { + "type": "string" + }, + { + "type": "array" + }, + { + "type": "object" + } + ] + } + }, + "type": { + "type": "string", + "const": "metric", + "default": "metric" + }, + "metric": { + "type": "string" + }, + "value": { + "oneOf": [ + { + "type": "integer" + }, + { + "type": "number" + } + ] + }, + "unit": { + "type": "string" + } + }, + "additionalProperties": false, + "required": [ + "trace_id", + "span_id", + "timestamp", + "type", + "metric", + "value", + "unit" + ] + }, "TokenLogProbs": { "type": "object", "properties": { @@ -3388,6 +3465,12 @@ "ChatCompletionResponseStreamChunk": { "type": "object", "properties": { + "metrics": { + "type": "array", + "items": { + "$ref": "#/components/schemas/MetricEvent" + } + }, "event": { "$ref": "#/components/schemas/ChatCompletionResponseEvent", "description": "The event containing the new content" @@ -6374,77 +6457,6 @@ "critical" ] }, - "MetricEvent": { - "type": "object", - "properties": { - "trace_id": { - "type": "string" - }, - "span_id": { - "type": "string" - }, - "timestamp": { - "type": "string", - "format": "date-time" - }, - "attributes": { - "type": "object", - "additionalProperties": { - "oneOf": [ - { - "type": "null" - }, - { - "type": "boolean" - }, - { - "type": "number" - }, - { - "type": "string" - }, - { - "type": "array" - }, - { - "type": "object" - } - ] - } - }, - "type": { - "type": "string", - "const": "metric", - "default": "metric" - }, - "metric": { - "type": "string" - }, - "value": { - "oneOf": [ - { - "type": "integer" - }, - { - "type": "number" - } - ] - }, - "unit": { - "type": "string" - } - }, - "additionalProperties": false, - "required": [ - "trace_id", - "span_id", - "timestamp", - "type", - "metric", - "value", - "unit" - ] - }, "SpanEndPayload": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 37fba4541..c60a002e2 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -1925,6 +1925,10 @@ components: ChatCompletionResponse: type: object properties: + metrics: + type: array + items: + $ref: '#/components/schemas/MetricEvent' completion_message: $ref: '#/components/schemas/CompletionMessage' description: The complete response message @@ -1938,6 +1942,47 @@ components: required: - completion_message description: Response from a chat completion request. 
+ MetricEvent: + type: object + properties: + trace_id: + type: string + span_id: + type: string + timestamp: + type: string + format: date-time + attributes: + type: object + additionalProperties: + oneOf: + - type: 'null' + - type: boolean + - type: number + - type: string + - type: array + - type: object + type: + type: string + const: metric + default: metric + metric: + type: string + value: + oneOf: + - type: integer + - type: number + unit: + type: string + additionalProperties: false + required: + - trace_id + - span_id + - timestamp + - type + - metric + - value + - unit TokenLogProbs: type: object properties: @@ -2173,6 +2218,10 @@ components: ChatCompletionResponseStreamChunk: type: object properties: + metrics: + type: array + items: + $ref: '#/components/schemas/MetricEvent' event: $ref: '#/components/schemas/ChatCompletionResponseEvent' description: The event containing the new content @@ -4070,47 +4119,6 @@ components: - warn - error - critical - MetricEvent: - type: object - properties: - trace_id: - type: string - span_id: - type: string - timestamp: - type: string - format: date-time - attributes: - type: object - additionalProperties: - oneOf: - - type: 'null' - - type: boolean - - type: number - - type: string - - type: array - - type: object - type: - type: string - const: metric - default: metric - metric: - type: string - value: - oneOf: - - type: integer - - type: number - unit: - type: string - additionalProperties: false - required: - - trace_id - - span_id - - timestamp - - type - - metric - - value - - unit SpanEndPayload: type: object properties: diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 4e095e831..9fccd3911 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -13,8 +13,8 @@ from typing import ( Literal, Optional, Protocol, - runtime_checkable, Union, + runtime_checkable, ) from llama_models.llama3.api.datatypes import ( @@ -31,6 +31,7 @@ from typing_extensions import Annotated from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent from llama_stack.apis.models import Model +from llama_stack.apis.telemetry.telemetry import MetricResponseMixin from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol @@ -357,7 +358,7 @@ class ChatCompletionRequest(BaseModel): @json_schema_type -class ChatCompletionResponseStreamChunk(BaseModel): +class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel): """A chunk of a streamed chat completion response. :param event: The event containing the new content @@ -367,7 +368,7 @@ class ChatCompletionResponseStreamChunk(BaseModel): @json_schema_type -class ChatCompletionResponse(BaseModel): +class ChatCompletionResponse(MetricResponseMixin, BaseModel): """Response from a chat completion request. 
:param completion_message: The complete response message diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index 324064007..6a62e274d 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -13,8 +13,8 @@ from typing import ( Literal, Optional, Protocol, - runtime_checkable, Union, + runtime_checkable, ) from llama_models.schema_utils import json_schema_type, register_schema, webmethod @@ -94,6 +94,30 @@ class MetricEvent(EventCommon): unit: str +# This is a short term solution to allow inference API to return metrics +# The ideal way to do this is to have a way for all response types to include metrics +# and all metric events logged to the telemetry API to be inlcuded with the response +# To do this, we will need to augment all response types with a metrics field. +# We have hit a blocker from stainless SDK that prevents us from doing this. +# The blocker is that if we were to augment the response types that have a data field +# in them like so +# class ListModelsResponse(BaseModel): +# metrics: Optional[List[MetricEvent]] = None +# data: List[Models] +# ... +# The client SDK will need to access the data by using a .data field, which is not +# ergonomic. Stainless SDK does support unwrapping the response type, but it +# requires that the response type to only have a single field. + +# We will need a way in the client SDK to signal that the metrics are needed +# and if they are needed, the client SDK has to return the full response type +# without unwrapping it. + + +class MetricResponseMixin(BaseModel): + metrics: Optional[List[MetricEvent]] = None + + @json_schema_type class StructuredLogType(Enum): SPAN_START = "span_start" From d8a20e034b5fff1c2fa1a32d2c0313020d64a08c Mon Sep 17 00:00:00 2001 From: Dinesh Yeduguru Date: Tue, 11 Feb 2025 15:10:17 -0800 Subject: [PATCH 08/27] feat: make telemetry attributes be dict[str,PrimitiveType] (#1055) # What does this PR do? Make attributes in telemetry be only primitive types and avoid arbitrary nesting. ## Test Plan ``` LLAMA_STACK_DISABLE_VERSION_CHECK=true llama stack run ~/.llama/distributions/fireworks/fireworks-run.yaml LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/agents/test_agents.py -k "test_builtin_tool_web_search" # Verified that attributes still show up correclty in jaeger ``` --- docs/_static/llama-stack-spec.html | 36 +++++++------------ docs/_static/llama-stack-spec.yaml | 28 +++++++-------- llama_stack/apis/telemetry/telemetry.py | 3 +- .../utils/telemetry/trace_protocol.py | 7 ++-- 4 files changed, 29 insertions(+), 45 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 75e0c4dfa..98270f7b8 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -3148,22 +3148,19 @@ "additionalProperties": { "oneOf": [ { - "type": "null" + "type": "string" }, { - "type": "boolean" + "type": "integer" }, { "type": "number" }, { - "type": "string" + "type": "boolean" }, { - "type": "array" - }, - { - "type": "object" + "type": "null" } ] } @@ -3683,8 +3680,7 @@ "auto", "required" ], - "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model.", - "default": "auto" + "description": "Whether tool use is required or automatic. This is a hint to the model which may not be followed. 
It depends on the Instruction Following capabilities of the model." }, "tool_prompt_format": { "type": "string", @@ -6514,22 +6510,19 @@ "additionalProperties": { "oneOf": [ { - "type": "null" + "type": "string" }, { - "type": "boolean" + "type": "integer" }, { "type": "number" }, { - "type": "string" + "type": "boolean" }, { - "type": "array" - }, - { - "type": "object" + "type": "null" } ] } @@ -6587,22 +6580,19 @@ "additionalProperties": { "oneOf": [ { - "type": "null" + "type": "string" }, { - "type": "boolean" + "type": "integer" }, { "type": "number" }, { - "type": "string" + "type": "boolean" }, { - "type": "array" - }, - { - "type": "object" + "type": "null" } ] } diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index c60a002e2..a646d7e08 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -1956,12 +1956,11 @@ components: type: object additionalProperties: oneOf: - - type: 'null' - - type: boolean - - type: number - type: string - - type: array - - type: object + - type: integer + - type: number + - type: boolean + - type: 'null' type: type: string const: metric @@ -2387,7 +2386,6 @@ components: Whether tool use is required or automatic. This is a hint to the model which may not be followed. It depends on the Instruction Following capabilities of the model. - default: auto tool_prompt_format: type: string enum: @@ -4161,12 +4159,11 @@ components: type: object additionalProperties: oneOf: - - type: 'null' - - type: boolean - - type: number - type: string - - type: array - - type: object + - type: integer + - type: number + - type: boolean + - type: 'null' type: type: string const: structured_log @@ -4203,12 +4200,11 @@ components: type: object additionalProperties: oneOf: - - type: 'null' - - type: boolean - - type: number - type: string - - type: array - - type: object + - type: integer + - type: number + - type: boolean + - type: 'null' type: type: string const: unstructured_log diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index 6a62e274d..6272cc40b 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -17,6 +17,7 @@ from typing import ( runtime_checkable, ) +from llama_models.llama3.api.datatypes import Primitive from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, Field from typing_extensions import Annotated @@ -76,7 +77,7 @@ class EventCommon(BaseModel): trace_id: str span_id: str timestamp: datetime - attributes: Optional[Dict[str, Any]] = Field(default_factory=dict) + attributes: Optional[Dict[str, Primitive]] = Field(default_factory=dict) @json_schema_type diff --git a/llama_stack/providers/utils/telemetry/trace_protocol.py b/llama_stack/providers/utils/telemetry/trace_protocol.py index 1d6988c1e..80c58a2c7 100644 --- a/llama_stack/providers/utils/telemetry/trace_protocol.py +++ b/llama_stack/providers/utils/telemetry/trace_protocol.py @@ -9,12 +9,13 @@ import inspect from functools import wraps from typing import Any, AsyncGenerator, Callable, Type, TypeVar +from llama_models.llama3.api.datatypes import Primitive from pydantic import BaseModel T = TypeVar("T") -def serialize_value(value: Any) -> Any: +def serialize_value(value: Any) -> Primitive: """Serialize a single value into JSON-compatible format.""" if value is None: return "" @@ -24,10 +25,6 @@ def serialize_value(value: Any) -> Any: return value._name_ elif isinstance(value, 
BaseModel): return value.model_dump_json() - elif isinstance(value, (list, tuple, set)): - return [serialize_value(item) for item in value] - elif isinstance(value, dict): - return {str(k): serialize_value(v) for k, v in value.items()} else: return str(value) From 24385cfd03e75ce85ef10d61d12a199036fc0852 Mon Sep 17 00:00:00 2001 From: Ihar Hrachyshka Date: Tue, 11 Feb 2025 19:12:46 -0500 Subject: [PATCH 09/27] fix: filter out remote::sample providers when listing (#1057) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? Before: ```  llama stack list-providers agents +------------------------+-----------------------------------------------------------------------+ | Provider Type | PIP Package Dependencies | +------------------------+-----------------------------------------------------------------------+ | inline::meta-reference | matplotlib,pillow,pandas,scikit-learn,aiosqlite,psycopg2-binary,redis | +------------------------+-----------------------------------------------------------------------+ | remote::sample | | +------------------------+-----------------------------------------------------------------------+ ``` After: ```  llama stack list-providers agents +------------------------+-----------------------------------------------------------------------+ | Provider Type | PIP Package Dependencies | +------------------------+-----------------------------------------------------------------------+ | inline::meta-reference | matplotlib,pillow,pandas,scikit-learn,aiosqlite,psycopg2-binary,redis | +------------------------+-----------------------------------------------------------------------+ ``` [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan Manually. [//]: # (## Documentation) Signed-off-by: Ihar Hrachyshka --- llama_stack/cli/stack/list_providers.py | 2 +- llama_stack/providers/datatypes.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/llama_stack/cli/stack/list_providers.py b/llama_stack/cli/stack/list_providers.py index 909fea030..bd152c980 100644 --- a/llama_stack/cli/stack/list_providers.py +++ b/llama_stack/cli/stack/list_providers.py @@ -47,7 +47,7 @@ class StackListProviders(Subcommand): rows = [] for spec in providers_for_api.values(): - if spec.provider_type == "sample": + if spec.is_sample: continue rows.append( [ diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index d0c448f8c..8df91cce6 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -86,6 +86,10 @@ class ProviderSpec(BaseModel): # used internally by the resolver; this is a hack for now deps__: List[str] = Field(default_factory=list) + @property + def is_sample(self) -> bool: + return self.provider_type in ("sample", "remote::sample") + class RoutingTable(Protocol): def get_provider_impl(self, routing_key: str) -> Any: ... From dd37e588688ced330f1c58bfb4041dc9d7605eab Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Tue, 11 Feb 2025 21:08:29 -0500 Subject: [PATCH 10/27] feat: Support tool calling for non-streaming chat completion in remote vLLM provider (#1034) # What does this PR do? This PR adds support for tool calling for non-streaming chat completion. Prior to this, tool calls were not passed to chat completion requests and the tools object needs to be restructured properly to be compatible with vLLM provider. 
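For illustration, the request-side conversion this PR introduces (`_convert_to_vllm_tools_in_request` in the diff below) reshapes a Llama Stack `ToolDefinition` into the OpenAI-compatible tool schema that vLLM expects. The tool name and parameter below are made up purely for the example:

```python
# Illustration only: "get_weather" and its "location" parameter are hypothetical.
# A ToolDefinition roughly like
#   ToolDefinition(tool_name="get_weather",
#                  description="Get the current weather",
#                  parameters={"location": ToolParamDefinition(param_type="string",
#                                                              description="City name",
#                                                              required=True)})
# is expected to be reshaped into the OpenAI-compatible structure accepted by vLLM:
expected_vllm_tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "City name"},
            },
            "required": ["location"],
        },
    },
}
```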
## Test Plan ``` LLAMA_STACK_BASE_URL=http://localhost:5002 pytest -v tests/client-sdk/inference/test_text_inference.py ================================================================= test session starts ================================================================= platform linux -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /home/yutang/.conda/envs/distribution-myenv/bin/python3.10 cachedir: .pytest_cache rootdir: /home/yutang/repos/llama-stack configfile: pyproject.toml plugins: anyio-4.8.0 collected 12 items tests/client-sdk/inference/test_text_inference.py::test_text_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 8%] tests/client-sdk/inference/test_text_inference.py::test_text_completion_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 16%] tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote:...) [ 25%] tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote::vll...) [ 33%] tests/client-sdk/inference/test_text_inference.py::test_text_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 41%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-Which planet do humans live on?-Earth] PASSED [ 50%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-Which planet has rings around it with a name starting with letter S?-Saturn] PASSED [ 58%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What's the name of the Sun in latin?-Sol] PASSED [ 66%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What is the name of the US captial?-Washington] PASSED [ 75%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 83%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[meta-llama/Llama-3.1-8B-Instruct] FAILED [ 91%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED [100%] ``` --------- Signed-off-by: Yuan Tang --- .../providers/remote/inference/vllm/vllm.py | 85 ++++++++++++++++++- .../utils/inference/openai_compat.py | 2 + 2 files changed, 84 insertions(+), 3 deletions(-) diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 482e6fa97..8618abccf 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -3,10 +3,11 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
- +import json import logging from typing import AsyncGenerator, List, Optional, Union +from llama_models.llama3.api import StopReason, ToolCall from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import all_registered_models @@ -30,6 +31,7 @@ from llama_stack.apis.inference import ( ToolConfig, ToolDefinition, ToolPromptFormat, + CompletionMessage, ) from llama_stack.apis.models import Model, ModelType from llama_stack.providers.datatypes import ModelsProtocolPrivate @@ -40,7 +42,6 @@ from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.openai_compat import ( convert_message_to_openai_dict, get_sampling_options, - process_chat_completion_response, process_chat_completion_stream_response, process_completion_response, process_completion_stream_response, @@ -68,6 +69,73 @@ def build_model_aliases(): ] +def _convert_to_vllm_tool_calls_in_response( + tool_calls, +) -> List[ToolCall]: + if not tool_calls: + return [] + + call_function_arguments = None + for call in tool_calls: + call_function_arguments = json.loads(call.function.arguments) + + return [ + ToolCall( + call_id=call.id, + tool_name=call.function.name, + arguments=call_function_arguments, + ) + for call in tool_calls + ] + + +def _convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]: + if tools is None: + return tools + + compat_tools = [] + + for tool in tools: + properties = {} + compat_required = [] + if tool.parameters: + for tool_key, tool_param in tool.parameters.items(): + properties[tool_key] = {"type": tool_param.param_type} + if tool_param.description: + properties[tool_key]["description"] = tool_param.description + if tool_param.default: + properties[tool_key]["default"] = tool_param.default + if tool_param.required: + compat_required.append(tool_key) + + compat_tool = { + "type": "function", + "function": { + "name": tool.tool_name, + "description": tool.description, + "parameters": { + "type": "object", + "properties": properties, + "required": compat_required, + }, + }, + } + + compat_tools.append(compat_tool) + + if len(compat_tools) > 0: + return compat_tools + return None + + +def _convert_to_vllm_finish_reason(finish_reason: str) -> StopReason: + return { + "stop": StopReason.end_of_turn, + "length": StopReason.out_of_tokens, + "tool_calls": StopReason.end_of_message, + }.get(finish_reason, StopReason.end_of_turn) + + class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): def __init__(self, config: VLLMInferenceAdapterConfig) -> None: self.register_helper = ModelRegistryHelper(build_model_aliases()) @@ -142,7 +210,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): ) -> ChatCompletionResponse: params = await self._get_params(request) r = client.chat.completions.create(**params) - return process_chat_completion_response(r, self.formatter) + choice = r.choices[0] + result = ChatCompletionResponse( + completion_message=CompletionMessage( + content=choice.message.content or "", + stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason), + tool_calls=_convert_to_vllm_tool_calls_in_response(choice.message.tool_calls), + ), + logprobs=None, + ) + return result async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator: params = await self._get_params(request) @@ -193,6 +270,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): options["max_tokens"] = 
self.config.max_tokens input_dict = {} + if isinstance(request, ChatCompletionRequest) and request.tools is not None: + input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)} if isinstance(request, ChatCompletionRequest): input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages] diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index a3e893d8f..8ee838d84 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -174,6 +174,8 @@ def process_chat_completion_response( ) -> ChatCompletionResponse: choice = response.choices[0] + # TODO: This does not work well with tool calls for vLLM remote provider + # Ref: https://github.com/meta-llama/llama-stack/issues/1058 raw_message = formatter.decode_assistant_message_from_content( text_from_choice(choice), get_stop_reason(choice.finish_reason) ) From 66d7e15c93b2ae0c1de075abcc5fe019987ef111 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Tue, 11 Feb 2025 18:31:35 -0800 Subject: [PATCH 11/27] perf: ensure ToolCall in ChatCompletionResponse is subset of ChatCompletionRequest.tools (#1041) # What does this PR do? **Problem** - Using script: https://gist.github.com/thoraxe/6163b2145ce7b1c24c6026b64cf90085 - This hits an issue on server with `code_interpreter` not found, as we do not pass "builtin::code_interpreter" in AgentConfig's `toolgroups`. This is a general issue where model always tries to output `code_interpreter` in `ToolCall` even when we do not have `code_interpreter` available for execution. **Reproduce Deeper Problem in chat-completion** - Use script: https://gist.github.com/yanxi0830/163a9ad7b5db10556043fbfc7ecd7603 1. We currently always populate `code_interpreter` in `ToolCall` in ChatCompletionResponse if the model's response begins with `<|python_tag|>`. See https://github.com/meta-llama/llama-models/blob/c5f59584982e6f1c5ce2dd5a9d2a5763891ec276/models/llama3/api/chat_format.py#L200-L213 image 2. This happens even if we do not pass the `code_interpreter` as a `tools` in ChatCompletionRequest. **This PR** Explicitly make sure that the tools returned in `ChatCompletionResponse.tool_calls` is always a tool requested by `ChatCompletionRequest.tools`. 
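Conceptually, the change boils down to filtering decoded tool calls against the tools that were actually offered in the request. The snippet below is a simplified sketch of that idea, not the exact code added to `openai_compat.py`:

```python
# Simplified sketch (not the exact implementation): only surface tool calls whose
# tool_name was offered in ChatCompletionRequest.tools; anything else is treated
# as a failed tool-call parse and falls back to the raw text content.
def filter_to_requested_tools(tool_calls, requested_tools):
    requested = {t.tool_name for t in (requested_tools or [])}
    return [tc for tc in tool_calls if tc.tool_name in requested]
```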
[//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan **Before** image image **After** image image **Unit Test** ``` LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request --inference-model "meta-llama/Llama-3.3-70B-Instruct" ``` ``` LLAMA_STACK_BASE_URL=http://localhost:8321 pytest -v tests/client-sdk/agents/ ``` image **Streaming** - Chat Completion image - Agent image [//]: # (## Documentation) [//]: # (- [ ] Added a Changelog entry if the change is significant) --- .../agents/meta_reference/agent_instance.py | 3 + .../providers/inline/inference/vllm/vllm.py | 4 +- .../remote/inference/bedrock/bedrock.py | 4 +- .../remote/inference/cerebras/cerebras.py | 4 +- .../remote/inference/databricks/databricks.py | 4 +- .../remote/inference/fireworks/fireworks.py | 4 +- .../remote/inference/ollama/ollama.py | 4 +- .../remote/inference/runpod/runpod.py | 4 +- .../remote/inference/sambanova/sambanova.py | 2 +- .../providers/remote/inference/tgi/tgi.py | 4 +- .../remote/inference/together/together.py | 4 +- .../providers/remote/inference/vllm/vllm.py | 2 +- .../utils/inference/openai_compat.py | 70 +++++++++++++--- .../inference/test_text_inference.py | 84 ++++++++++++++++++- 14 files changed, 164 insertions(+), 33 deletions(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 51691c546..2f397f438 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -513,6 +513,9 @@ class ChatAgent(ShieldRunnerMixin): if delta.type == "tool_call": if delta.parse_status == ToolCallParseStatus.succeeded: tool_calls.append(delta.tool_call) + elif delta.parse_status == ToolCallParseStatus.failed: + # If we cannot parse the tools, set the content to the unparsed raw text + content = delta.tool_call if stream: yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index 691737c15..77c95cc7e 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -201,7 +201,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): response = OpenAICompatCompletionResponse( choices=[choice], ) - return process_chat_completion_response(response, self.formatter) + return process_chat_completion_response(response, self.formatter, request) async def _stream_chat_completion( self, request: ChatCompletionRequest, results_generator: AsyncGenerator @@ -227,7 +227,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): ) stream = _generate_and_convert_to_openai_compat() - async for chunk in process_chat_completion_stream_response(stream, self.formatter): + async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): yield chunk async def embeddings(self, model_id: str, contents: List[InterleavedContent]) -> EmbeddingsResponse: diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index 03a0a40c3..54a674d7e 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -134,7 
+134,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): ) response = OpenAICompatCompletionResponse(choices=[choice]) - return process_chat_completion_response(response, self.formatter) + return process_chat_completion_response(response, self.formatter, request) async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: params = await self._get_params_for_chat_completion(request) @@ -152,7 +152,7 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference): yield OpenAICompatCompletionResponse(choices=[choice]) stream = _generate_and_convert_to_openai_compat() - async for chunk in process_chat_completion_stream_response(stream, self.formatter): + async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): yield chunk async def _get_params_for_chat_completion(self, request: ChatCompletionRequest) -> Dict: diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index bd12c56c8..47f208129 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -155,14 +155,14 @@ class CerebrasInferenceAdapter(ModelRegistryHelper, Inference): r = await self.client.completions.create(**params) - return process_chat_completion_response(r, self.formatter) + return process_chat_completion_response(r, self.formatter, request) async def _stream_chat_completion(self, request: CompletionRequest) -> AsyncGenerator: params = await self._get_params(request) stream = await self.client.completions.create(**params) - async for chunk in process_chat_completion_stream_response(stream, self.formatter): + async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): yield chunk async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict: diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index 37070b4ce..ee3c6e99b 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -112,7 +112,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): ) -> ChatCompletionResponse: params = self._get_params(request) r = client.completions.create(**params) - return process_chat_completion_response(r, self.formatter) + return process_chat_completion_response(r, self.formatter, request) async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator: params = self._get_params(request) @@ -123,7 +123,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference): yield chunk stream = _to_async_generator() - async for chunk in process_chat_completion_stream_response(stream, self.formatter): + async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): yield chunk def _get_params(self, request: ChatCompletionRequest) -> dict: diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index d47c035b8..d978cb02e 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -230,7 +230,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv r = await 
self._get_client().chat.completions.acreate(**params) else: r = await self._get_client().completion.acreate(**params) - return process_chat_completion_response(r, self.formatter) + return process_chat_completion_response(r, self.formatter, request) async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: params = await self._get_params(request) @@ -244,7 +244,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv yield chunk stream = _to_async_generator() - async for chunk in process_chat_completion_stream_response(stream, self.formatter): + async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): yield chunk async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict: diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index ecd195854..05a5d2d7a 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -304,7 +304,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): response = OpenAICompatCompletionResponse( choices=[choice], ) - return process_chat_completion_response(response, self.formatter) + return process_chat_completion_response(response, self.formatter, request) async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: params = await self._get_params(request) @@ -330,7 +330,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate): ) stream = _generate_and_convert_to_openai_compat() - async for chunk in process_chat_completion_stream_response(stream, self.formatter): + async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): yield chunk async def embeddings( diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index a62b0c97f..c7b20b9a1 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -99,7 +99,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference): ) -> ChatCompletionResponse: params = self._get_params(request) r = client.completions.create(**params) - return process_chat_completion_response(r, self.formatter) + return process_chat_completion_response(r, self.formatter, request) async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator: params = self._get_params(request) @@ -110,7 +110,7 @@ class RunpodInferenceAdapter(ModelRegistryHelper, Inference): yield chunk stream = _to_async_generator() - async for chunk in process_chat_completion_stream_response(stream, self.formatter): + async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): yield chunk def _get_params(self, request: ChatCompletionRequest) -> dict: diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 87aab1e88..18a78e69c 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -160,7 +160,7 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference): yield chunk stream = _to_async_generator() - async for chunk in process_chat_completion_stream_response(stream, self.formatter): + async for chunk in 
process_chat_completion_stream_response(stream, self.formatter, request): yield chunk async def embeddings( diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 2281319b3..97a6621fb 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -236,7 +236,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): response = OpenAICompatCompletionResponse( choices=[choice], ) - return process_chat_completion_response(response, self.formatter) + return process_chat_completion_response(response, self.formatter, request) async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: params = await self._get_params(request) @@ -252,7 +252,7 @@ class _HfAdapter(Inference, ModelsProtocolPrivate): ) stream = _generate_and_convert_to_openai_compat() - async for chunk in process_chat_completion_stream_response(stream, self.formatter): + async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): yield chunk async def _get_params(self, request: ChatCompletionRequest) -> dict: diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index cf24daf60..a165b01d9 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -220,7 +220,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi r = self._get_client().chat.completions.create(**params) else: r = self._get_client().completions.create(**params) - return process_chat_completion_response(r, self.formatter) + return process_chat_completion_response(r, self.formatter, request) async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: params = await self._get_params(request) @@ -235,7 +235,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi yield chunk stream = _to_async_generator() - async for chunk in process_chat_completion_stream_response(stream, self.formatter): + async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): yield chunk async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict: diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 8618abccf..2e13a6262 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -232,7 +232,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): yield chunk stream = _to_async_generator() - async for chunk in process_chat_completion_stream_response(stream, self.formatter): + async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): yield chunk async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse: diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 8ee838d84..1047c9a58 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -3,7 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
- +import logging from typing import AsyncGenerator, Dict, List, Optional, Union from llama_models.datatypes import ( @@ -26,6 +26,7 @@ from llama_stack.apis.common.content_types import ( ) from llama_stack.apis.inference import ( + ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseEvent, ChatCompletionResponseEventType, @@ -41,6 +42,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( convert_image_content_to_url, ) +logger = logging.getLogger(__name__) + class OpenAICompatCompletionChoiceDelta(BaseModel): content: str @@ -170,7 +173,9 @@ def process_completion_response(response: OpenAICompatCompletionResponse, format def process_chat_completion_response( - response: OpenAICompatCompletionResponse, formatter: ChatFormat + response: OpenAICompatCompletionResponse, + formatter: ChatFormat, + request: ChatCompletionRequest, ) -> ChatCompletionResponse: choice = response.choices[0] @@ -179,6 +184,28 @@ def process_chat_completion_response( raw_message = formatter.decode_assistant_message_from_content( text_from_choice(choice), get_stop_reason(choice.finish_reason) ) + + # NOTE: If we do not set tools in chat-completion request, we should not + # expect the ToolCall in the response. Instead, we should return the raw + # response from the model. + if raw_message.tool_calls: + if not request.tools: + raw_message.tool_calls = [] + raw_message.content = text_from_choice(choice) + else: + # only return tool_calls if provided in the request + new_tool_calls = [] + request_tools = {t.tool_name: t for t in request.tools} + for t in raw_message.tool_calls: + if t.tool_name in request_tools: + new_tool_calls.append(t) + else: + logger.warning(f"Tool {t.tool_name} not found in request tools") + + if len(new_tool_calls) < len(raw_message.tool_calls): + raw_message.tool_calls = new_tool_calls + raw_message.content = text_from_choice(choice) + return ChatCompletionResponse( completion_message=CompletionMessage( content=raw_message.content, @@ -226,7 +253,9 @@ async def process_completion_stream_response( async def process_chat_completion_stream_response( - stream: AsyncGenerator[OpenAICompatCompletionResponse, None], formatter: ChatFormat + stream: AsyncGenerator[OpenAICompatCompletionResponse, None], + formatter: ChatFormat, + request: ChatCompletionRequest, ) -> AsyncGenerator: yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( @@ -305,6 +334,7 @@ async def process_chat_completion_stream_response( # parse tool calls and report errors message = formatter.decode_assistant_message_from_content(buffer, stop_reason) + parsed_tool_calls = len(message.tool_calls) > 0 if ipython and not parsed_tool_calls: yield ChatCompletionResponseStreamChunk( @@ -318,17 +348,33 @@ async def process_chat_completion_stream_response( ) ) + request_tools = {t.tool_name: t for t in request.tools} for tool_call in message.tool_calls: - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta( - tool_call=tool_call, - parse_status=ToolCallParseStatus.succeeded, - ), - stop_reason=stop_reason, + if tool_call.tool_name in request_tools: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + tool_call=tool_call, + parse_status=ToolCallParseStatus.succeeded, + ), + stop_reason=stop_reason, + ) + ) + else: + logger.warning(f"Tool {tool_call.tool_name} not found in request 
tools") + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.progress, + delta=ToolCallDelta( + # Parsing tool call failed due to tool call not being found in request tools, + # We still add the raw message text inside tool_call for responding back to the user + tool_call=buffer, + parse_status=ToolCallParseStatus.failed, + ), + stop_reason=stop_reason, + ) ) - ) yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( diff --git a/tests/client-sdk/inference/test_text_inference.py b/tests/client-sdk/inference/test_text_inference.py index 81b476218..206629602 100644 --- a/tests/client-sdk/inference/test_text_inference.py +++ b/tests/client-sdk/inference/test_text_inference.py @@ -158,7 +158,10 @@ def test_text_completion_structured_output(llama_stack_client, text_model_id, in "question,expected", [ ("Which planet do humans live on?", "Earth"), - ("Which planet has rings around it with a name starting with letter S?", "Saturn"), + ( + "Which planet has rings around it with a name starting with letter S?", + "Saturn", + ), ], ) def test_text_chat_completion_non_streaming(llama_stack_client, text_model_id, question, expected): @@ -280,3 +283,82 @@ def test_text_chat_completion_structured_output(llama_stack_client, text_model_i assert answer.last_name == "Jordan" assert answer.year_of_birth == 1963 assert answer.num_seasons_in_nba == 15 + + +@pytest.mark.parametrize( + "streaming", + [ + True, + False, + ], +) +def test_text_chat_completion_tool_calling_tools_not_in_request(llama_stack_client, text_model_id, streaming): + # TODO: more dynamic lookup on tool_prompt_format for model family + tool_prompt_format = "json" if "3.1" in text_model_id else "python_list" + request = { + "model_id": text_model_id, + "messages": [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": "What pods are in the namespace openshift-lightspeed?", + }, + { + "role": "assistant", + "content": "", + "stop_reason": "end_of_turn", + "tool_calls": [ + { + "call_id": "1", + "tool_name": "get_object_namespace_list", + "arguments": { + "kind": "pod", + "namespace": "openshift-lightspeed", + }, + } + ], + }, + { + "role": "tool", + "call_id": "1", + "tool_name": "get_object_namespace_list", + "content": "the objects are pod1, pod2, pod3", + }, + ], + "tools": [ + { + "tool_name": "get_object_namespace_list", + "description": "Get the list of objects in a namespace", + "parameters": { + "kind": { + "param_type": "string", + "description": "the type of object", + "required": True, + }, + "namespace": { + "param_type": "string", + "description": "the name of the namespace", + "required": True, + }, + }, + } + ], + "tool_choice": "auto", + "tool_prompt_format": tool_prompt_format, + "stream": streaming, + } + + response = llama_stack_client.inference.chat_completion(**request) + + if streaming: + for chunk in response: + delta = chunk.event.delta + if delta.type == "tool_call" and delta.parse_status == "succeeded": + assert delta.tool_call.tool_name == "get_object_namespace_list" + if delta.type == "tool_call" and delta.parse_status == "failed": + # expect raw message that failed to parse in tool_call + assert type(delta.tool_call) == str + assert len(delta.tool_call) > 0 + else: + for tc in response.completion_message.tool_calls: + assert tc.tool_name == "get_object_namespace_list" From bf11cc0450722ac7ec728f0a57f3388545ce4c8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= 
Date: Wed, 12 Feb 2025 07:10:28 +0100 Subject: [PATCH 12/27] chore: update return type to Optional[str] (#982) --- .../providers/utils/inference/model_registry.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index 5746af4ba..dea951395 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -57,17 +57,11 @@ class ModelRegistryHelper(ModelsProtocolPrivate): self.alias_to_provider_id_map[alias_obj.llama_model] = alias_obj.provider_model_id self.provider_id_to_llama_model_map[alias_obj.provider_model_id] = alias_obj.llama_model - def get_provider_model_id(self, identifier: str) -> str: - if identifier in self.alias_to_provider_id_map: - return self.alias_to_provider_id_map[identifier] - else: - return None + def get_provider_model_id(self, identifier: str) -> Optional[str]: + return self.alias_to_provider_id_map.get(identifier, None) - def get_llama_model(self, provider_model_id: str) -> str: - if provider_model_id in self.provider_id_to_llama_model_map: - return self.provider_id_to_llama_model_map[provider_model_id] - else: - return None + def get_llama_model(self, provider_model_id: str) -> Optional[str]: + return self.provider_id_to_llama_model_map.get(provider_model_id, None) async def register_model(self, model: Model) -> Model: if model.model_type == ModelType.embedding: From 5e97dd991932f7b625fe50d2a5dd6e830cec6346 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Wed, 12 Feb 2025 09:17:21 -0500 Subject: [PATCH 13/27] feat: Support tool calling for streaming chat completion in remote vLLM provider (#1063) # What does this PR do? [Provide a short summary of what this PR does and why. Link to relevant issues if applicable.] Closes https://github.com/meta-llama/llama-stack/issues/1046. ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] ``` LLAMA_STACK_BASE_URL=http://localhost:5002 pytest -v tests/client-sdk/inference/test_text_inference.py ================================================================= test session starts ================================================================= platform linux -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /home/yutang/.conda/envs/distribution-myenv/bin/python3.10 cachedir: .pytest_cache rootdir: /home/yutang/repos/llama-stack configfile: pyproject.toml plugins: anyio-4.8.0 collected 14 items tests/client-sdk/inference/test_text_inference.py::test_text_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 7%] tests/client-sdk/inference/test_text_inference.py::test_text_completion_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 14%] tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote:...) [ 21%] tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] XFAIL (remote::vll...) 
[ 28%] tests/client-sdk/inference/test_text_inference.py::test_text_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 35%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-Which planet do humans live on?-Earth] PASSED [ 42%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_non_streaming[meta-llama/Llama-3.1-8B-Instruct-Which planet has rings around it with a name starting with letter S?-Saturn] PASSED [ 50%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What's the name of the Sun in latin?-Sol] PASSED [ 57%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_streaming[meta-llama/Llama-3.1-8B-Instruct-What is the name of the US captial?-Washington] PASSED [ 64%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 71%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_with_tool_calling_and_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 78%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_structured_output[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 85%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[meta-llama/Llama-3.1-8B-Instruct-True] PASSED [ 92%] tests/client-sdk/inference/test_text_inference.py::test_text_chat_completion_tool_calling_tools_not_in_request[meta-llama/Llama-3.1-8B-Instruct-False] PASSED [100%] =============================================== 12 passed, 2 xfailed, 1 warning in 366.56s (0:06:06) ================================================ ``` --------- Signed-off-by: Yuan Tang --- .../remote/inference/groq/groq_utils.py | 46 ++------------ .../providers/remote/inference/vllm/vllm.py | 61 ++++++++++++++++++- .../utils/inference/openai_compat.py | 39 +++++++++++- 3 files changed, 101 insertions(+), 45 deletions(-) diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py index 537043d69..d00e5c5a9 100644 --- a/llama_stack/providers/remote/inference/groq/groq_utils.py +++ b/llama_stack/providers/remote/inference/groq/groq_utils.py @@ -6,7 +6,7 @@ import json import warnings -from typing import AsyncGenerator, Literal, Union +from typing import AsyncGenerator, Literal from groq import Stream from groq.types.chat.chat_completion import ChatCompletion @@ -15,9 +15,6 @@ from groq.types.chat.chat_completion_assistant_message_param import ( ) from groq.types.chat.chat_completion_chunk import ChatCompletionChunk from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam -from groq.types.chat.chat_completion_message_tool_call import ( - ChatCompletionMessageToolCall, -) from groq.types.chat.chat_completion_system_message_param import ( ChatCompletionSystemMessageParam, ) @@ -30,7 +27,6 @@ from groq.types.shared.function_definition import FunctionDefinition from llama_models.llama3.api.datatypes import ToolParamDefinition -from pydantic import BaseModel from llama_stack.apis.common.content_types import ( TextDelta, @@ -52,6 +48,8 @@ from llama_stack.apis.inference import ( ) from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_strategy_options, + convert_tool_call, + UnparseableToolCall, ) @@ -143,7 +141,7 @@ def 
convert_chat_completion_response( # groq only supports n=1 at time of writing, so there is only one choice choice = response.choices[0] if choice.finish_reason == "tool_calls": - tool_calls = [_convert_groq_tool_call(tool_call) for tool_call in choice.message.tool_calls] + tool_calls = [convert_tool_call(tool_call) for tool_call in choice.message.tool_calls] if any(isinstance(tool_call, UnparseableToolCall) for tool_call in tool_calls): # If we couldn't parse a tool call, jsonify the tool calls and return them return ChatCompletionResponse( @@ -216,7 +214,7 @@ async def convert_chat_completion_response_stream( warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.") # We assume Groq produces fully formed tool calls for each chunk - tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0]) + tool_call = convert_tool_call(choice.delta.tool_calls[0]) if isinstance(tool_call, ToolCall): yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( @@ -247,37 +245,3 @@ async def convert_chat_completion_response_stream( ) ) event_type = ChatCompletionResponseEventType.progress - - -class UnparseableToolCall(BaseModel): - """ - A ToolCall with arguments that are not valid JSON. - Mirrors the ToolCall schema, but with arguments as a string. - """ - - call_id: str - tool_name: str - arguments: str - - -def _convert_groq_tool_call( - tool_call: ChatCompletionMessageToolCall, -) -> Union[ToolCall, UnparseableToolCall]: - """ - Convert a Groq tool call to a ToolCall. - Returns an UnparseableToolCall if the tool call is not valid JSON. - """ - try: - arguments = json.loads(tool_call.function.arguments) - except Exception as e: - return UnparseableToolCall( - call_id=tool_call.id, - tool_name=tool_call.function.name, - arguments=tool_call.function.arguments, - ) - - return ToolCall( - call_id=tool_call.id, - tool_name=tool_call.function.name, - arguments=arguments, - ) diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 2e13a6262..02594891b 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -13,7 +13,7 @@ from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import all_registered_models from openai import OpenAI -from llama_stack.apis.common.content_types import InterleavedContent +from llama_stack.apis.common.content_types import InterleavedContent, ToolCallDelta, ToolCallParseStatus, TextDelta from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -32,6 +32,9 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, CompletionMessage, + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, + ChatCompletionResponseEvent, ) from llama_stack.apis.models import Model, ModelType from llama_stack.providers.datatypes import ModelsProtocolPrivate @@ -42,9 +45,12 @@ from llama_stack.providers.utils.inference.model_registry import ( from llama_stack.providers.utils.inference.openai_compat import ( convert_message_to_openai_dict, get_sampling_options, - process_chat_completion_stream_response, process_completion_response, process_completion_stream_response, + OpenAICompatCompletionResponse, + UnparseableToolCall, + convert_tool_call, + process_chat_completion_stream_response, ) from llama_stack.providers.utils.inference.prompt_adapter import ( completion_request_to_prompt, @@ -136,6 +142,51 @@ def 
_convert_to_vllm_finish_reason(finish_reason: str) -> StopReason: }.get(finish_reason, StopReason.end_of_turn) +async def _process_vllm_chat_completion_stream_response( + stream: AsyncGenerator[OpenAICompatCompletionResponse, None], +) -> AsyncGenerator: + event_type = ChatCompletionResponseEventType.start + tool_call_buf = UnparseableToolCall() + async for chunk in stream: + choice = chunk.choices[0] + if choice.finish_reason: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=event_type, + delta=ToolCallDelta( + tool_call=ToolCall( + call_id=tool_call_buf.call_id, + tool_name=tool_call_buf.tool_name, + arguments=json.loads(tool_call_buf.arguments), + ), + parse_status=ToolCallParseStatus.succeeded, + ), + ) + ) + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=ChatCompletionResponseEventType.complete, + delta=TextDelta(text=choice.delta.content or ""), + logprobs=None, + stop_reason=_convert_to_vllm_finish_reason(choice.finish_reason), + ) + ) + elif choice.delta.tool_calls: + tool_call = convert_tool_call(choice.delta.tool_calls[0]) + tool_call_buf.tool_name += tool_call.tool_name + tool_call_buf.call_id += tool_call.call_id + tool_call_buf.arguments += tool_call.arguments + else: + yield ChatCompletionResponseStreamChunk( + event=ChatCompletionResponseEvent( + event_type=event_type, + delta=TextDelta(text=choice.delta.content or ""), + logprobs=None, + ) + ) + event_type = ChatCompletionResponseEventType.progress + + class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): def __init__(self, config: VLLMInferenceAdapterConfig) -> None: self.register_helper = ModelRegistryHelper(build_model_aliases()) @@ -232,7 +283,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): yield chunk stream = _to_async_generator() - async for chunk in process_chat_completion_stream_response(stream, self.formatter, request): + if len(request.tools) > 0: + res = _process_vllm_chat_completion_stream_response(stream) + else: + res = process_chat_completion_stream_response(stream, self.formatter, request) + async for chunk in res: yield chunk async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse: diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 1047c9a58..7480ff2c7 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -3,6 +3,7 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json import logging from typing import AsyncGenerator, Dict, List, Optional, Union @@ -14,7 +15,8 @@ from llama_models.datatypes import ( ) from llama_models.llama3.api.chat_format import ChatFormat -from llama_models.llama3.api.datatypes import StopReason +from llama_models.llama3.api.datatypes import StopReason, ToolCall +from openai.types.chat import ChatCompletionMessageToolCall from pydantic import BaseModel from llama_stack.apis.common.content_types import ( @@ -408,3 +410,38 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals "role": message.role, "content": content, } + + +class UnparseableToolCall(BaseModel): + """ + A ToolCall with arguments that are not valid JSON. + Mirrors the ToolCall schema, but with arguments as a string. 
+ """ + + call_id: str = "" + tool_name: str = "" + arguments: str = "" + + +def convert_tool_call( + tool_call: ChatCompletionMessageToolCall, +) -> Union[ToolCall, UnparseableToolCall]: + """ + Convert a ChatCompletionMessageToolCall tool call to either a + ToolCall or UnparseableToolCall. Returns an UnparseableToolCall + if the tool call is not valid JSON. + """ + try: + arguments = json.loads(tool_call.function.arguments) + except Exception as e: + return UnparseableToolCall( + call_id=tool_call.id or "", + tool_name=tool_call.function.name or "", + arguments=tool_call.function.arguments or "", + ) + + return ToolCall( + call_id=tool_call.id, + tool_name=tool_call.function.name, + arguments=arguments, + ) From 5f88ff0b6a3dd2d5c216f0e50b7eaa14b3c318d2 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Wed, 12 Feb 2025 09:38:25 -0500 Subject: [PATCH 14/27] fix: show proper help text (#1065) # What does this PR do? when executing a sub-command like `llama model` the improper help text, sub-commands, and flags are displayed. each command group needs to have `.set_defaults` to display this info properly before: ``` llama model usage: llama [-h] {model,stack,download,verify-download} ... Welcome to the Llama CLI options: -h, --help show this help message and exit subcommands: {model,stack,download,verify-download} ``` after: ``` llama model usage: llama model [-h] {download,list,prompt-format,describe,verify-download} ... Work with llama models options: -h, --help show this help message and exit model_subcommands: {download,list,prompt-format,describe,verify-download} ``` Signed-off-by: Charlie Doern --- llama_stack/cli/model/model.py | 2 ++ llama_stack/cli/stack/stack.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/llama_stack/cli/model/model.py b/llama_stack/cli/model/model.py index f59ba8376..02e7f216f 100644 --- a/llama_stack/cli/model/model.py +++ b/llama_stack/cli/model/model.py @@ -26,6 +26,8 @@ class ModelParser(Subcommand): description="Work with llama models", ) + self.parser.set_defaults(func=lambda args: self.parser.print_help()) + subparsers = self.parser.add_subparsers(title="model_subcommands") # Add sub-commands diff --git a/llama_stack/cli/stack/stack.py b/llama_stack/cli/stack/stack.py index 8650bd728..10e49f8c9 100644 --- a/llama_stack/cli/stack/stack.py +++ b/llama_stack/cli/stack/stack.py @@ -31,6 +31,8 @@ class StackParser(Subcommand): version=f"{version('llama-stack')}", ) + self.parser.set_defaults(func=lambda args: self.parser.print_help()) + subparsers = self.parser.add_subparsers(title="stack_subcommands") # Add sub-commands From 025f6158684bf647a94213cb76a5d5b3b23735f4 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Wed, 12 Feb 2025 11:13:04 -0500 Subject: [PATCH 15/27] feat: add support for running in a venv (#1018) # What does this PR do? add --image-type to `llama stack run`. Which takes conda, container or venv also add start_venv.sh which start the stack using a venv resolves #1007 ## Test Plan running locally: `llama stack build --template ollama --image-type venv` `llama stack run --image-type venv ~/.llama/distributions/ollama/ollama-run.yaml` ... 
``` llama stack run --image-type venv ~/.llama/distributions/ollama/ollama-run.yaml Using run configuration: /Users/charliedoern/.llama/distributions/ollama/ollama-run.yaml + python -m llama_stack.distribution.server.server --yaml-config /Users/charliedoern/.llama/distributions/ollama/ollama-run.yaml --port 8321 Using config file: /Users/charliedoern/.llama/distributions/ollama/ollama-run.yaml Run configuration: apis: - agents - datasetio ... ``` Signed-off-by: Charlie Doern --- docs/source/distributions/building_distro.md | 33 +++++++++ llama_stack/cli/stack/run.py | 20 +++++- llama_stack/distribution/start_venv.sh | 71 ++++++++++++++++++++ 3 files changed, 122 insertions(+), 2 deletions(-) create mode 100755 llama_stack/distribution/start_venv.sh diff --git a/docs/source/distributions/building_distro.md b/docs/source/distributions/building_distro.md index 5556d4aa1..90239cb4e 100644 --- a/docs/source/distributions/building_distro.md +++ b/docs/source/distributions/building_distro.md @@ -180,12 +180,45 @@ After this step is successful, you should be able to find the built container im ### Running your Stack server Now, let's start the Llama Stack Distribution Server. You will need the YAML configuration file which was written out at the end by the `llama stack build` step. +``` +llama stack run -h +usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-ipv6] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE] + [--tls-certfile TLS_CERTFILE] [--image-type {conda,container,venv}] + config + +start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution. + +positional arguments: + config Path to config file to use for the run + +options: + -h, --help show this help message and exit + --port PORT Port to run the server on. Defaults to 8321 + --image-name IMAGE_NAME + Name of the image to run. Defaults to the current conda environment + --disable-ipv6 Disable IPv6 support + --env KEY=VALUE Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times. + --tls-keyfile TLS_KEYFILE + Path to TLS key file for HTTPS + --tls-certfile TLS_CERTFILE + Path to TLS certificate file for HTTPS + --image-type {conda,container,venv} + Image Type used during the build. This can be either conda or container or venv. + +``` + ``` # Start using template name llama stack run tgi # Start using config file llama stack run ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml + +# Start using a venv +llama stack run --image-type venv ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml + +# Start using a conda environment +llama stack run --image-type conda ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml ``` ``` diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index e7d6df292..c32e51fca 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -65,6 +65,13 @@ class StackRun(Subcommand): type=str, help="Path to TLS certificate file for HTTPS", ) + self.parser.add_argument( + "--image-type", + type=str, + help="Image Type used during the build. 
This can be either conda or container or venv.", + choices=["conda", "container", "venv"], + default="conda", + ) def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: import importlib.resources @@ -118,11 +125,11 @@ class StackRun(Subcommand): config_dict = yaml.safe_load(config_file.read_text()) config = parse_and_maybe_upgrade_config(config_dict) - if config.container_image: + if args.image_type == ImageType.container.value or config.container_image: script = importlib.resources.files("llama_stack") / "distribution/start_container.sh" image_name = f"distribution-{template_name}" if template_name else config.container_image run_args = [script, image_name] - else: + elif args.image_type == ImageType.conda.value: current_conda_env = os.environ.get("CONDA_DEFAULT_ENV") image_name = args.image_name or current_conda_env if not image_name: @@ -167,6 +174,15 @@ class StackRun(Subcommand): script, image_name, ] + else: + # else must be venv since that is the only valid option left. + current_venv = os.environ.get("VIRTUAL_ENV") + venv = args.image_name or current_venv + script = importlib.resources.files("llama_stack") / "distribution/start_venv.sh" + run_args = [ + script, + venv, + ] run_args.extend([str(config_file), str(args.port)]) if args.disable_ipv6: diff --git a/llama_stack/distribution/start_venv.sh b/llama_stack/distribution/start_venv.sh new file mode 100755 index 000000000..1cfa7248f --- /dev/null +++ b/llama_stack/distribution/start_venv.sh @@ -0,0 +1,71 @@ +#!/bin/bash + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +set -euo pipefail + +RED='\033[0;31m' +NC='\033[0m' # No Color + +error_handler() { + echo "Error occurred in script at line: ${1}" >&2 + exit 1 +} + +trap 'error_handler ${LINENO}' ERR + +if [ $# -lt 3 ]; then + echo "Usage: $0 " + exit 1 +fi + +venv_path="$1" +shift + +yaml_config="$1" +shift + +port="$1" +shift + +# Initialize env_vars as an empty array +env_vars="" +other_args="" +# Process environment variables from --env arguments +while [[ $# -gt 0 ]]; do + case "$1" in + --env) + + if [[ -n "$2" ]]; then + env_vars="$env_vars --env $2" + shift 2 + else + echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2 + exit 1 + fi + ;; + *) + other_args="$other_args $1" + shift + ;; + esac +done + +# Activate virtual environment +if [ ! -d "$venv_path" ]; then + echo -e "${RED}Error: Virtual environment not found at $venv_path${NC}" >&2 + exit 1 +fi + +source "$venv_path/bin/activate" + +set -x +python -m llama_stack.distribution.server.server \ + --yaml-config "$yaml_config" \ + --port "$port" \ + $env_vars \ + $other_args From 119fe8742a90e7f9d237af4655be046cd4fcd81a Mon Sep 17 00:00:00 2001 From: Francisco Arceo Date: Wed, 12 Feb 2025 13:50:03 -0500 Subject: [PATCH 16/27] feat: Adding sqlite-vec as a vectordb (#1040) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? This PR adds `sqlite_vec` as an additional inline vectordb. 
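For context on how the new provider works under the hood, `sqlite-vec` exposes `vec0` virtual tables that store embeddings as float32 blobs and answer KNN queries via `MATCH`. Below is a minimal standalone sketch of that pattern, not the provider code itself; the table name, the 4-dimensional vectors, and the query values are illustrative, and it assumes the `sqlite-vec` Python package is installed:

```python
import sqlite3
import struct

import sqlite_vec  # assumed installed via `pip install sqlite-vec`


def serialize_vector(vector):
    # sqlite-vec stores vectors as compact float32 blobs
    return struct.pack(f"{len(vector)}f", *vector)


db = sqlite3.connect(":memory:")
db.enable_load_extension(True)
sqlite_vec.load(db)
db.enable_load_extension(False)

# One vec0 virtual table per vector DB, sized to the embedding dimension (4 here for brevity)
db.execute("CREATE VIRTUAL TABLE vec_demo USING vec0(embedding FLOAT[4])")
db.execute(
    "INSERT INTO vec_demo (rowid, embedding) VALUES (?, ?)",
    (1, serialize_vector([0.1, 0.2, 0.3, 0.4])),
)

# KNN query: MATCH takes the serialized query vector, k bounds the number of neighbours
rows = db.execute(
    "SELECT rowid, distance FROM vec_demo WHERE embedding MATCH ? AND k = ? ORDER BY distance",
    (serialize_vector([0.1, 0.2, 0.25, 0.4]), 1),
).fetchall()
print(rows)  # [(1, <distance>)]
```

The `SQLiteVecIndex` added in this PR follows the same shape, pairing each `vec0` table with a plain metadata table that stores the chunk JSON under the matching rowid.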
Tested with `ollama` by adding the `vector_io` object in `./llama_stack/templates/ollama/run.yaml` : ```yaml vector_io: - provider_id: sqlite_vec provider_type: inline::sqlite_vec config: kvstore: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db ``` I also updated the `./tests/client-sdk/vector_io/test_vector_io.py` test file with: ```python INLINE_VECTOR_DB_PROVIDERS = ["faiss", "sqlite_vec"] ``` And parameterized the relevant tests. [//]: # (If resolving an issue, uncomment and update the line below) # Closes https://github.com/meta-llama/llama-stack/issues/1005 ## Test Plan I ran the tests with: ```bash INFERENCE_MODEL=llama3.2:3b-instruct-fp16 LLAMA_STACK_CONFIG=ollama pytest -s -v tests/client-sdk/vector_io/test_vector_io.py ``` Which outputs: ```python ... PASSED tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_retrieve[all-MiniLM-L6-v2-sqlite_vec] PASSED tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_list PASSED tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_register[all-MiniLM-L6-v2-faiss] PASSED tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_register[all-MiniLM-L6-v2-sqlite_vec] PASSED tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_unregister[faiss] PASSED tests/client-sdk/vector_io/test_vector_io.py::test_vector_db_unregister[sqlite_vec] PASSED ``` In addition, I ran the `rag_with_vector_db.py` [example](https://github.com/meta-llama/llama-stack-apps/blob/main/examples/agents/rag_with_vector_db.py) using the script below with `uv run rag_example.py`.
CLICK TO SHOW SCRIPT 👋 ```python #!/usr/bin/env python3 import os import uuid from termcolor import cprint # Set environment variables os.environ['INFERENCE_MODEL'] = 'llama3.2:3b-instruct-fp16' os.environ['LLAMA_STACK_CONFIG'] = 'ollama' # Import libraries after setting environment variables from llama_stack.distribution.library_client import LlamaStackAsLibraryClient from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.types.agent_create_params import AgentConfig from llama_stack_client.types import Document def main(): # Initialize the client client = LlamaStackAsLibraryClient("ollama") vector_db_id = f"test-vector-db-{uuid.uuid4().hex}" _ = client.initialize() model_id = 'llama3.2:3b-instruct-fp16' # Define the list of document URLs and create Document objects urls = [ "chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst", ] documents = [ Document( document_id=f"num-{i}", content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", mime_type="text/plain", metadata={}, ) for i, url in enumerate(urls) ] # (Optional) Use the documents as needed with your client here client.vector_dbs.register( provider_id='sqlite_vec', vector_db_id=vector_db_id, embedding_model="all-MiniLM-L6-v2", embedding_dimension=384, ) client.tool_runtime.rag_tool.insert( documents=documents, vector_db_id=vector_db_id, chunk_size_in_tokens=512, ) # Create agent configuration agent_config = AgentConfig( model=model_id, instructions="You are a helpful assistant", enable_session_persistence=False, toolgroups=[ { "name": "builtin::rag", "args": { "vector_db_ids": [vector_db_id], } } ], ) # Instantiate the Agent agent = Agent(client, agent_config) # List of user prompts user_prompts = [ "What are the top 5 topics that were explained in the documentation? Only list succinct bullet points.", "Was anything related to 'Llama3' discussed, if so what?", "Tell me how to use LoRA", "What about Quantization?", ] # Create a session for the agent session_id = agent.create_session("test-session") # Process each prompt and display the output for prompt in user_prompts: cprint(f"User> {prompt}", "green") response = agent.create_turn( messages=[ { "role": "user", "content": prompt, } ], session_id=session_id, ) # Log and print events from the response for log in EventLogger().log(response): log.print() if __name__ == "__main__": main() ```
Which outputs a large summary of RAG generation. # Documentation Will handle documentation updates in follow-up PR. # (- [ ] Added a Changelog entry if the change is significant) --------- Signed-off-by: Francisco Javier Arceo --- .../inline/vector_io/sqlite_vec/__init__.py | 18 ++ .../inline/vector_io/sqlite_vec/config.py | 28 +++ .../inline/vector_io/sqlite_vec/sqlite_vec.py | 214 ++++++++++++++++++ llama_stack/providers/registry/vector_io.py | 8 + .../providers/tests/vector_io/conftest.py | 8 + .../providers/tests/vector_io/fixtures.py | 25 +- llama_stack/templates/ollama/build.yaml | 1 + llama_stack/templates/ollama/ollama.py | 12 +- llama_stack/templates/ollama/run.yaml | 8 + tests/client-sdk/vector_io/test_vector_io.py | 21 +- 10 files changed, 331 insertions(+), 12 deletions(-) create mode 100644 llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py create mode 100644 llama_stack/providers/inline/vector_io/sqlite_vec/config.py create mode 100644 llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py b/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py new file mode 100644 index 000000000..488a57660 --- /dev/null +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import Dict +from llama_stack.providers.datatypes import Api, ProviderSpec +from .config import SQLiteVectorIOConfig + + +async def get_provider_impl(config: SQLiteVectorIOConfig, deps: Dict[Api, ProviderSpec]): + from .sqlite_vec import SQLiteVecVectorIOAdapter + + assert isinstance(config, SQLiteVectorIOConfig), f"Unexpected config type: {type(config)}" + impl = SQLiteVecVectorIOAdapter(config, deps[Api.inference]) + await impl.initialize() + return impl diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/config.py b/llama_stack/providers/inline/vector_io/sqlite_vec/config.py new file mode 100644 index 000000000..60fe3ca2a --- /dev/null +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/config.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# config.py +from pydantic import BaseModel +from typing import Any, Dict + +from llama_stack.providers.utils.kvstore.config import ( + KVStoreConfig, + SqliteKVStoreConfig, +) + + +class SQLiteVectorIOConfig(BaseModel): + db_path: str + kvstore: KVStoreConfig + + @classmethod + def sample_run_config(cls, __distro_dir__: str) -> Dict[str, Any]: + return { + "kvstore": SqliteKVStoreConfig.sample_run_config( + __distro_dir__=__distro_dir__, + db_name="sqlite_vec.db", + ) + } diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py new file mode 100644 index 000000000..019d260f8 --- /dev/null +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import sqlite3 +import sqlite_vec +import struct +import logging +import numpy as np +from numpy.typing import NDArray +from typing import List, Optional, Dict, Any + +from llama_stack.apis.vector_dbs import VectorDB +from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO +from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate +from llama_stack.providers.utils.memory.vector_store import EmbeddingIndex, VectorDBWithIndex + +logger = logging.getLogger(__name__) + + +def serialize_vector(vector: List[float]) -> bytes: + """Serialize a list of floats into a compact binary representation.""" + return struct.pack(f"{len(vector)}f", *vector) + + +class SQLiteVecIndex(EmbeddingIndex): + """ + An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec. + Two tables are used: + - A metadata table (chunks_{bank_id}) that holds the chunk JSON. + - A virtual table (vec_chunks_{bank_id}) that holds the serialized vector. + """ + + def __init__(self, dimension: int, connection: sqlite3.Connection, bank_id: str): + self.dimension = dimension + self.connection = connection + self.bank_id = bank_id + self.metadata_table = f"chunks_{bank_id}".replace("-", "_") + self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_") + + @classmethod + async def create(cls, dimension: int, connection: sqlite3.Connection, bank_id: str): + instance = cls(dimension, connection, bank_id) + await instance.initialize() + return instance + + async def initialize(self) -> None: + cur = self.connection.cursor() + # Create the table to store chunk metadata. + cur.execute(f""" + CREATE TABLE IF NOT EXISTS {self.metadata_table} ( + id INTEGER PRIMARY KEY, + chunk TEXT + ); + """) + # Create the virtual table for embeddings. + cur.execute(f""" + CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table} + USING vec0(embedding FLOAT[{self.dimension}]); + """) + self.connection.commit() + + async def delete(self): + cur = self.connection.cursor() + cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};") + cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};") + self.connection.commit() + + async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): + """ + Add new chunks along with their embeddings. + For each chunk, we insert its JSON into the metadata table and then insert its + embedding (serialized to raw bytes) into the virtual table using the assigned rowid. + If any insert fails, the transaction is rolled back to maintain consistency. + """ + cur = self.connection.cursor() + try: + # Start transaction + cur.execute("BEGIN TRANSACTION") + for chunk, emb in zip(chunks, embeddings): + # Serialize and insert the chunk metadata. + chunk_json = chunk.model_dump_json() + cur.execute(f"INSERT INTO {self.metadata_table} (chunk) VALUES (?)", (chunk_json,)) + row_id = cur.lastrowid + # Ensure the embedding is a list of floats. 
+ emb_list = emb.tolist() if isinstance(emb, np.ndarray) else list(emb) + emb_blob = serialize_vector(emb_list) + cur.execute(f"INSERT INTO {self.vector_table} (rowid, embedding) VALUES (?, ?)", (row_id, emb_blob)) + # Commit transaction if all inserts succeed + self.connection.commit() + + except sqlite3.Error as e: + self.connection.rollback() # Rollback on failure + print(f"Error inserting into {self.vector_table} - error: {e}") # Log error (Consider using logging module) + + finally: + cur.close() # Ensure cursor is closed + + async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse: + """ + Query for the k most similar chunks. We convert the query embedding to a blob and run a SQL query + against the virtual table. The SQL joins the metadata table to recover the chunk JSON. + """ + emb_list = embedding.tolist() if isinstance(embedding, np.ndarray) else list(embedding) + emb_blob = serialize_vector(emb_list) + cur = self.connection.cursor() + query_sql = f""" + SELECT m.id, m.chunk, v.distance + FROM {self.vector_table} AS v + JOIN {self.metadata_table} AS m ON m.id = v.rowid + WHERE v.embedding MATCH ? AND k = ? + ORDER BY v.distance; + """ + cur.execute(query_sql, (emb_blob, k)) + rows = cur.fetchall() + chunks = [] + scores = [] + for _id, chunk_json, distance in rows: + try: + chunk = Chunk.model_validate_json(chunk_json) + except Exception as e: + logger.error(f"Error parsing chunk JSON for id {_id}: {e}") + continue + chunks.append(chunk) + # Mimic the Faiss scoring: score = 1/distance (avoid division by zero) + score = 1.0 / distance if distance != 0 else float("inf") + scores.append(score) + return QueryChunksResponse(chunks=chunks, scores=scores) + + +class SQLiteVecVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): + """ + A VectorIO implementation using SQLite + sqlite_vec. + This class handles vector database registration (with metadata stored in a table named `vector_dbs`) + and creates a cache of VectorDBWithIndex instances (each wrapping a SQLiteVecIndex). + """ + + def __init__(self, config, inference_api: Api.inference) -> None: + self.config = config + self.inference_api = inference_api + self.cache: Dict[str, VectorDBWithIndex] = {} + self.connection: Optional[sqlite3.Connection] = None + + async def initialize(self) -> None: + # Open a connection to the SQLite database (the file is specified in the config). + self.connection = sqlite3.connect(self.config.db_path) + self.connection.enable_load_extension(True) + sqlite_vec.load(self.connection) + self.connection.enable_load_extension(False) + cur = self.connection.cursor() + # Create a table to persist vector DB registrations. + cur.execute(""" + CREATE TABLE IF NOT EXISTS vector_dbs ( + id TEXT PRIMARY KEY, + metadata TEXT + ); + """) + self.connection.commit() + # Load any existing vector DB registrations. 
+ cur.execute("SELECT metadata FROM vector_dbs") + rows = cur.fetchall() + for row in rows: + vector_db_data = row[0] + vector_db = VectorDB.model_validate_json(vector_db_data) + index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier) + self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) + + async def shutdown(self) -> None: + if self.connection: + self.connection.close() + self.connection = None + + async def register_vector_db(self, vector_db: VectorDB) -> None: + if self.connection is None: + raise RuntimeError("SQLite connection not initialized") + cur = self.connection.cursor() + cur.execute( + "INSERT OR REPLACE INTO vector_dbs (id, metadata) VALUES (?, ?)", + (vector_db.identifier, vector_db.model_dump_json()), + ) + self.connection.commit() + index = await SQLiteVecIndex.create(vector_db.embedding_dimension, self.connection, vector_db.identifier) + self.cache[vector_db.identifier] = VectorDBWithIndex(vector_db, index, self.inference_api) + + async def list_vector_dbs(self) -> List[VectorDB]: + return [v.vector_db for v in self.cache.values()] + + async def unregister_vector_db(self, vector_db_id: str) -> None: + if self.connection is None: + raise RuntimeError("SQLite connection not initialized") + if vector_db_id not in self.cache: + logger.warning(f"Vector DB {vector_db_id} not found") + return + await self.cache[vector_db_id].index.delete() + del self.cache[vector_db_id] + cur = self.connection.cursor() + cur.execute("DELETE FROM vector_dbs WHERE id = ?", (vector_db_id,)) + self.connection.commit() + + async def insert_chunks(self, vector_db_id: str, chunks: List[Chunk], ttl_seconds: Optional[int] = None) -> None: + if vector_db_id not in self.cache: + raise ValueError(f"Vector DB {vector_db_id} not found. Found: {list(self.cache.keys())}") + # The VectorDBWithIndex helper is expected to compute embeddings via the inference_api + # and then call our index’s add_chunks. 
+ await self.cache[vector_db_id].insert_chunks(chunks) + + async def query_chunks( + self, vector_db_id: str, query: Any, params: Optional[Dict[str, Any]] = None + ) -> QueryChunksResponse: + if vector_db_id not in self.cache: + raise ValueError(f"Vector DB {vector_db_id} not found") + return await self.cache[vector_db_id].query_chunks(query, params) diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index 2d7c02d86..4422baba5 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -54,6 +54,14 @@ def available_providers() -> List[ProviderSpec]: config_class="llama_stack.providers.inline.vector_io.faiss.FaissImplConfig", api_dependencies=[Api.inference], ), + InlineProviderSpec( + api=Api.vector_io, + provider_type="inline::sqlite_vec", + pip_packages=EMBEDDING_DEPS + ["sqlite-vec"], + module="llama_stack.providers.inline.vector_io.sqlite_vec", + config_class="llama_stack.providers.inline.vector_io.sqlite_vec.SQLiteVectorIOConfig", + api_dependencies=[Api.inference], + ), remote_provider_spec( Api.vector_io, AdapterSpec( diff --git a/llama_stack/providers/tests/vector_io/conftest.py b/llama_stack/providers/tests/vector_io/conftest.py index 1feb5af92..3a02ac712 100644 --- a/llama_stack/providers/tests/vector_io/conftest.py +++ b/llama_stack/providers/tests/vector_io/conftest.py @@ -41,6 +41,14 @@ DEFAULT_PROVIDER_COMBINATIONS = [ id="ollama", marks=pytest.mark.ollama, ), + pytest.param( + { + "inference": "ollama", + "vector_io": "sqlite_vec", + }, + id="sqlite_vec", + marks=pytest.mark.ollama, + ), pytest.param( { "inference": "sentence_transformers", diff --git a/llama_stack/providers/tests/vector_io/fixtures.py b/llama_stack/providers/tests/vector_io/fixtures.py index c8d5fa8cf..54a76141f 100644 --- a/llama_stack/providers/tests/vector_io/fixtures.py +++ b/llama_stack/providers/tests/vector_io/fixtures.py @@ -15,6 +15,7 @@ from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig from llama_stack.providers.inline.vector_io.faiss import FaissImplConfig +from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig from llama_stack.providers.remote.vector_io.chroma import ChromaRemoteImplConfig from llama_stack.providers.remote.vector_io.pgvector import PGVectorConfig from llama_stack.providers.remote.vector_io.weaviate import WeaviateConfig @@ -53,6 +54,22 @@ def vector_io_faiss() -> ProviderFixture: ) +@pytest.fixture(scope="session") +def vector_io_sqlite_vec() -> ProviderFixture: + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".db") + return ProviderFixture( + providers=[ + Provider( + provider_id="sqlite_vec", + provider_type="inline::sqlite_vec", + config=SQLiteVectorIOConfig( + kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(), + ).model_dump(), + ) + ], + ) + + @pytest.fixture(scope="session") def vector_io_pgvector() -> ProviderFixture: return ProviderFixture( @@ -111,7 +128,13 @@ def vector_io_chroma() -> ProviderFixture: ) -VECTOR_IO_FIXTURES = ["faiss", "pgvector", "weaviate", "chroma"] +VECTOR_IO_FIXTURES = [ + "faiss", + "pgvector", + "weaviate", + "chroma", + "sqlite_vec", +] @pytest_asyncio.fixture(scope="session") diff --git a/llama_stack/templates/ollama/build.yaml b/llama_stack/templates/ollama/build.yaml index 0fee6808c..48960c5ba 100644 --- a/llama_stack/templates/ollama/build.yaml +++ b/llama_stack/templates/ollama/build.yaml @@ 
-6,6 +6,7 @@ distribution_spec: - remote::ollama vector_io: - inline::faiss + - inline::sqlite_vec - remote::chromadb - remote::pgvector safety: diff --git a/llama_stack/templates/ollama/ollama.py b/llama_stack/templates/ollama/ollama.py index d14cb3aad..a762e757a 100644 --- a/llama_stack/templates/ollama/ollama.py +++ b/llama_stack/templates/ollama/ollama.py @@ -17,6 +17,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig +from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -49,11 +50,16 @@ def get_distribution_template() -> DistributionTemplate: provider_type="inline::sentence-transformers", config=SentenceTransformersInferenceConfig.sample_run_config(), ) - vector_io_provider = Provider( + vector_io_provider_faiss = Provider( provider_id="faiss", provider_type="inline::faiss", config=FaissImplConfig.sample_run_config(f"distributions/{name}"), ) + vector_io_provider_sqlite = Provider( + provider_id="sqlite_vec", + provider_type="inline::sqlite_vec", + config=SQLiteVectorIOConfig.sample_run_config(f"distributions/{name}"), + ) inference_model = ModelInput( model_id="${env.INFERENCE_MODEL}", @@ -98,7 +104,7 @@ def get_distribution_template() -> DistributionTemplate: "run.yaml": RunConfigSettings( provider_overrides={ "inference": [inference_provider, embedding_provider], - "vector_io": [vector_io_provider], + "vector_io": [vector_io_provider_faiss, vector_io_provider_sqlite], }, default_models=[inference_model, embedding_model], default_tool_groups=default_tool_groups, @@ -109,7 +115,7 @@ def get_distribution_template() -> DistributionTemplate: inference_provider, embedding_provider, ], - "vector_io": [vector_io_provider], + "vector_io": [vector_io_provider_faiss, vector_io_provider_faiss], "safety": [ Provider( provider_id="llama-guard", diff --git a/llama_stack/templates/ollama/run.yaml b/llama_stack/templates/ollama/run.yaml index 485223675..3a60fe61f 100644 --- a/llama_stack/templates/ollama/run.yaml +++ b/llama_stack/templates/ollama/run.yaml @@ -27,6 +27,14 @@ providers: type: sqlite namespace: null db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db + - provider_id: sqlite_vec + provider_type: inline::sqlite_vec + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/sqlite_vec.db safety: - provider_id: llama-guard provider_type: inline::llama-guard diff --git a/tests/client-sdk/vector_io/test_vector_io.py b/tests/client-sdk/vector_io/test_vector_io.py index 36d3fe2c1..c5be4ab3f 100644 --- a/tests/client-sdk/vector_io/test_vector_io.py +++ b/tests/client-sdk/vector_io/test_vector_io.py @@ -8,6 +8,8 @@ import random import pytest +INLINE_VECTOR_DB_PROVIDERS = ["faiss", "sqlite_vec"] + @pytest.fixture(scope="function") def empty_vector_db_registry(llama_stack_client): @@ -17,26 +19,27 @@ def empty_vector_db_registry(llama_stack_client): @pytest.fixture(scope="function") -def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry): +def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry, provider_id): vector_db_id = 
f"test_vector_db_{random.randint(1000, 9999)}" llama_stack_client.vector_dbs.register( vector_db_id=vector_db_id, embedding_model="all-MiniLM-L6-v2", embedding_dimension=384, - provider_id="faiss", + provider_id=provider_id, ) vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()] return vector_dbs -def test_vector_db_retrieve(llama_stack_client, embedding_model, empty_vector_db_registry): +@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS) +def test_vector_db_retrieve(llama_stack_client, embedding_model, empty_vector_db_registry, provider_id): # Register a memory bank first vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}" llama_stack_client.vector_dbs.register( vector_db_id=vector_db_id, embedding_model=embedding_model, embedding_dimension=384, - provider_id="faiss", + provider_id=provider_id, ) # Retrieve the memory bank and validate its properties @@ -44,7 +47,7 @@ def test_vector_db_retrieve(llama_stack_client, embedding_model, empty_vector_db assert response is not None assert response.identifier == vector_db_id assert response.embedding_model == embedding_model - assert response.provider_id == "faiss" + assert response.provider_id == provider_id assert response.provider_resource_id == vector_db_id @@ -53,20 +56,22 @@ def test_vector_db_list(llama_stack_client, empty_vector_db_registry): assert len(vector_dbs_after_register) == 0 -def test_vector_db_register(llama_stack_client, embedding_model, empty_vector_db_registry): +@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS) +def test_vector_db_register(llama_stack_client, embedding_model, empty_vector_db_registry, provider_id): vector_db_id = f"test_vector_db_{random.randint(1000, 9999)}" llama_stack_client.vector_dbs.register( vector_db_id=vector_db_id, embedding_model=embedding_model, embedding_dimension=384, - provider_id="faiss", + provider_id=provider_id, ) vector_dbs_after_register = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()] assert vector_dbs_after_register == [vector_db_id] -def test_vector_db_unregister(llama_stack_client, single_entry_vector_db_registry): +@pytest.mark.parametrize("provider_id", INLINE_VECTOR_DB_PROVIDERS) +def test_vector_db_unregister(llama_stack_client, single_entry_vector_db_registry, provider_id): vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()] assert len(vector_dbs) == 1 From cc700b2f683b18cde328a4fd68da665f6ab661b4 Mon Sep 17 00:00:00 2001 From: Ihar Hrachyshka Date: Thu, 13 Feb 2025 01:03:28 -0500 Subject: [PATCH 17/27] feat: support listing all for `llama stack list-providers` (#1056) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? Support listing all for `llama stack list-providers`. For ease of reading, sort the output rows by type. Before the change. ```  llama stack list-providers usage: llama stack list-providers [-h] {inference,safety,agents,vector_io,datasetio,scoring,eval,post_training,tool_runtime,telemetry} llama stack list-providers: error: the following arguments are required: api ``` After the change. 
``` +---------------+----------------------------------+----------------------------------------------------------------------------------+ | API Type | Provider Type | PIP Package Dependencies | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | agents | inline::meta-reference | matplotlib,pillow,pandas,scikit-learn,aiosqlite,psycopg2-binary,redis | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | datasetio | inline::localfs | pandas | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | datasetio | remote::huggingface | datasets | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | eval | inline::meta-reference | | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | inference | inline::meta-reference | accelerate,blobfile,fairscale,torch,torchvision,transformers,zmq,lm-format- | | | | enforcer,sentence-transformers | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | inference | inline::meta-reference-quantized | accelerate,blobfile,fairscale,torch,torchvision,transformers,zmq,lm-format- | | | | enforcer,sentence-transformers,fbgemm-gpu,torchao==0.5.0 | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | inference | inline::sentence-transformers | sentence-transformers | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | inference | inline::vllm | vllm | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | inference | remote::bedrock | boto3 | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | inference | remote::cerebras | cerebras_cloud_sdk | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | inference | remote::databricks | openai | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | inference | remote::fireworks | fireworks-ai | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | inference | remote::groq | groq | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | inference | remote::hf::endpoint | huggingface_hub,aiohttp | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | inference | remote::hf::serverless | huggingface_hub,aiohttp | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | inference | remote::nvidia | openai | 
+---------------+----------------------------------+----------------------------------------------------------------------------------+ | inference | remote::ollama | ollama,aiohttp | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | inference | remote::runpod | openai | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | inference | remote::sambanova | openai | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | inference | remote::tgi | huggingface_hub,aiohttp | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | inference | remote::together | together | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | inference | remote::vllm | openai | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | post_training | inline::torchtune | torch,torchtune==0.5.0,torchao==0.8.0,numpy | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | safety | inline::code-scanner | codeshield | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | safety | inline::llama-guard | | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | safety | inline::meta-reference | transformers,torch --index-url https://download.pytorch.org/whl/cpu | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | safety | inline::prompt-guard | transformers,torch --index-url https://download.pytorch.org/whl/cpu | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | safety | remote::bedrock | boto3 | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | scoring | inline::basic | | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | scoring | inline::braintrust | autoevals,openai | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | scoring | inline::llm-as-judge | | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | telemetry | inline::meta-reference | opentelemetry-sdk,opentelemetry-exporter-otlp-proto-http | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | tool_runtime | inline::code-interpreter | | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | tool_runtime | inline::rag-runtime | | 
+---------------+----------------------------------+----------------------------------------------------------------------------------+ | tool_runtime | remote::bing-search | requests | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | tool_runtime | remote::brave-search | requests | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | tool_runtime | remote::model-context-protocol | mcp | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | tool_runtime | remote::tavily-search | requests | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | tool_runtime | remote::wolfram-alpha | requests | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | vector_io | inline::chromadb | blobfile,chardet,pypdf,tqdm,numpy,scikit- | | | | learn,scipy,nltk,sentencepiece,transformers,torch torchvision --index-url | | | | https://download.pytorch.org/whl/cpu,sentence-transformers --no-deps,chromadb | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | vector_io | inline::faiss | blobfile,chardet,pypdf,tqdm,numpy,scikit- | | | | learn,scipy,nltk,sentencepiece,transformers,torch torchvision --index-url | | | | https://download.pytorch.org/whl/cpu,sentence-transformers --no-deps,faiss-cpu | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | vector_io | inline::meta-reference | blobfile,chardet,pypdf,tqdm,numpy,scikit- | | | | learn,scipy,nltk,sentencepiece,transformers,torch torchvision --index-url | | | | https://download.pytorch.org/whl/cpu,sentence-transformers --no-deps,faiss-cpu | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | vector_io | remote::chromadb | blobfile,chardet,pypdf,tqdm,numpy,scikit- | | | | learn,scipy,nltk,sentencepiece,transformers,torch torchvision --index-url | | | | https://download.pytorch.org/whl/cpu,sentence-transformers --no-deps,chromadb- | | | | client | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | vector_io | remote::pgvector | blobfile,chardet,pypdf,tqdm,numpy,scikit- | | | | learn,scipy,nltk,sentencepiece,transformers,torch torchvision --index-url | | | | https://download.pytorch.org/whl/cpu,sentence-transformers --no- | | | | deps,psycopg2-binary | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | vector_io | remote::qdrant | blobfile,chardet,pypdf,tqdm,numpy,scikit- | | | | learn,scipy,nltk,sentencepiece,transformers,torch torchvision --index-url | | | | https://download.pytorch.org/whl/cpu,sentence-transformers --no-deps,qdrant- | | | | client | +---------------+----------------------------------+----------------------------------------------------------------------------------+ | vector_io | remote::weaviate | blobfile,chardet,pypdf,tqdm,numpy,scikit- | | | | learn,scipy,nltk,sentencepiece,transformers,torch 
torchvision --index-url | | | | https://download.pytorch.org/whl/cpu,sentence-transformers --no-deps,weaviate- | | | | client | +---------------+----------------------------------+----------------------------------------------------------------------------------+ ``` [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan Manually. [//]: # (## Documentation) Signed-off-by: Ihar Hrachyshka --- llama_stack/cli/stack/list_providers.py | 26 +++++++++++++++++++------ llama_stack/cli/table.py | 7 ++++++- 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/llama_stack/cli/stack/list_providers.py b/llama_stack/cli/stack/list_providers.py index bd152c980..bfe11aa2c 100644 --- a/llama_stack/cli/stack/list_providers.py +++ b/llama_stack/cli/stack/list_providers.py @@ -21,15 +21,19 @@ class StackListProviders(Subcommand): self._add_arguments() self.parser.set_defaults(func=self._run_providers_list_cmd) - def _add_arguments(self): + @property + def providable_apis(self): from llama_stack.distribution.distribution import providable_apis - api_values = [api.value for api in providable_apis()] + return [api.value for api in providable_apis()] + + def _add_arguments(self): self.parser.add_argument( "api", type=str, - choices=api_values, - help="API to list providers for (one of: {})".format(api_values), + choices=self.providable_apis, + nargs="?", + help="API to list providers for. List all if not specified.", ) def _run_providers_list_cmd(self, args: argparse.Namespace) -> None: @@ -37,20 +41,29 @@ class StackListProviders(Subcommand): from llama_stack.distribution.distribution import Api, get_provider_registry all_providers = get_provider_registry() - providers_for_api = all_providers[Api(args.api)] + if args.api: + providers = [(args.api, all_providers[Api(args.api)])] + else: + providers = [(k.value, prov) for k, prov in all_providers.items()] + + providers = [p for api, p in providers if api in self.providable_apis] # eventually, this should query a registry at llama.meta.com/llamastack/distributions headers = [ + "API Type", "Provider Type", "PIP Package Dependencies", ] rows = [] - for spec in providers_for_api.values(): + + specs = [spec for p in providers for spec in p.values()] + for spec in specs: if spec.is_sample: continue rows.append( [ + spec.api.value, spec.provider_type, ",".join(spec.pip_packages), ] @@ -59,4 +72,5 @@ class StackListProviders(Subcommand): rows, headers, separate_rows=True, + sort_by=(0, 1), ) diff --git a/llama_stack/cli/table.py b/llama_stack/cli/table.py index 50f54852b..847719f81 100644 --- a/llama_stack/cli/table.py +++ b/llama_stack/cli/table.py @@ -6,6 +6,7 @@ import re import textwrap +from typing import Iterable from termcolor import cprint @@ -39,11 +40,15 @@ def format_row(row, col_widths): return "\n".join(lines) -def print_table(rows, headers=None, separate_rows: bool = False): +def print_table(rows, headers=None, separate_rows: bool = False, sort_by: Iterable[int] = tuple()): def itemlen(item): return max([len(line) for line in strip_ansi_colors(item).split("\n")]) rows = [[x or "" for x in row] for row in rows] + + if sort_by: + rows.sort(key=lambda x: tuple(x[i] for i in sort_by)) + if not headers: col_widths = [max(itemlen(item) for item in col) for col in zip(*rows)] else: From 8c01b7f05a16006e59b1ee0c3b102493f49b8cff Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Thu, 13 Feb 2025 10:57:30 -0500 Subject: [PATCH 18/27] docs: Mention convential commits format in CONTRIBUTING.md (#1075) 
# What does this PR do? This adds a note to ensure pull requests follow the conventional commits format, along with a link to that format, in CONTRIBUTING.md. One of the pull-request checks enforces PR titles that match this format, so it's good to be upfront about this expectation before a new developer opens a PR. Signed-off-by: Ben Browning --- CONTRIBUTING.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8028c194e..6dc08b5c0 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -40,6 +40,7 @@ If you need help or guidance, comment on the issue. Issues that are extra friend 3. Ensure the test suite passes. 4. Make sure your code lints using `pre-commit`. 5. If you haven't already, complete the Contributor License Agreement ("CLA"). +6. Ensure your pull request follows the [conventional commits format](https://www.conventionalcommits.org/en/v1.0.0/). ## Contributor License Agreement ("CLA") In order to accept your pull request, we need you to submit a CLA. You only need From dd1a366347da6d0d719721a819cced947beba681 Mon Sep 17 00:00:00 2001 From: Ben Browning Date: Thu, 13 Feb 2025 11:00:00 -0500 Subject: [PATCH 19/27] fix: logprobs support in remote-vllm provider (#1074) # What does this PR do? The remote-vllm provider was not passing logprobs options from CompletionRequest or ChatCompletionRequests through to the OpenAI client parameters. I manually verified this, as well as observed this provider failing `TestInference::test_completion_logprobs`. This was filed as issue #1073. This fixes that by passing the `logprobs.top_k` value through to the parameters we pass into the OpenAI client. Additionally, this fixes a bug in `test_text_inference.py` where it mistakenly assumed chunk.delta were of type `ContentDelta` for completion requests. The deltas are of type `ContentDelta` for chat completion requests, but for basic completion requests the deltas are of type string. This test was likely failing for other providers that did properly support logprobs because of this latter issue in the test, which was hit while fixing the above issue with the remote-vllm provider. (Closes #1073) ## Test Plan First, you need a vllm running. I ran one locally like this: ``` vllm serve meta-llama/Llama-3.2-3B-Instruct --port 8001 --enable-auto-tool-choice --tool-call-parser llama3_json ``` Next, run test_text_inference.py against this vllm using the remote vllm provider like this: ``` VLLM_URL="http://localhost:8001/v1" python -m pytest -s -v llama_stack/providers/tests/inference/test_text_inference.py --providers "inference=vllm_remote" ``` Before my change, the test failed with this error: ``` llama_stack/providers/tests/inference/test_text_inference.py:155: in test_completion_logprobs assert 1 <= len(response.logprobs) <= 5 E TypeError: object of type 'NoneType' has no len() ``` After my change, the test passes. 
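For readers skimming the diff below, here is a minimal sketch of the passthrough. The helper name `build_openai_params` is illustrative only; the real change is the two `logprobs` lines inside the adapter's existing parameter-building method, exactly as shown in the diff.

```python
# Illustrative sketch, not the adapter's actual method. `request` carries an
# optional LogProbConfig; previously its value was silently dropped, so the
# server returned no logprobs and `len(response.logprobs)` blew up on None.
def build_openai_params(request, input_dict: dict) -> dict:
    if request.logprobs and request.logprobs.top_k:
        input_dict["logprobs"] = request.logprobs.top_k
    return {"model": request.model, **input_dict}
```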
[//]: # (## Documentation) Signed-off-by: Ben Browning --- llama_stack/providers/remote/inference/vllm/vllm.py | 3 +++ llama_stack/providers/tests/inference/test_text_inference.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 02594891b..3574768b5 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -345,6 +345,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): else: raise ValueError(f"Unknown response format {fmt.type}") + if request.logprobs and request.logprobs.top_k: + input_dict["logprobs"] = request.logprobs.top_k + return { "model": request.model, **input_dict, diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 99f968cbc..6a7259123 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -175,7 +175,7 @@ class TestInference: 1 <= len(chunks) <= 6 ) # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason for chunk in chunks: - if chunk.delta.type == "text" and chunk.delta.text: # if there's a token, we expect logprobs + if chunk.delta: # if there's a token, we expect logprobs assert chunk.logprobs, "Logprobs should not be empty" assert all(len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs) else: # no token, no logprobs From 418645696ac7ddb3aa9c68959801f4cf41cc0b8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Thu, 13 Feb 2025 17:07:59 +0100 Subject: [PATCH 20/27] fix: improve signal handling and update dependencies (#1044) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? This commit enhances the signal handling mechanism in the server by improving the `handle_signal` (previously handle_sigint) function. It now properly retrieves the signal name, ensuring clearer logging when a termination signal is received. Additionally, it cancels all running tasks and waits for their completion before stopping the event loop, allowing for a more graceful shutdown. Support for handling SIGTERM has also been added alongside SIGINT. Before the changes, handle_sigint used asyncio.run(run_shutdown()). However, asyncio.run() is meant to start a new event loop, and calling it inside an existing one (like when running Uvicorn) raises an error. The fix replaces asyncio.run(run_shutdown()) with an async function scheduled on the existing loop using loop.create_task(shutdown()). This ensures that the shutdown coroutine runs within the current event loop instead of trying to create a new one. Furthermore, this commit updates the project dependencies. `fastapi` and `uvicorn` have been added to the development dependencies in `pyproject.toml` and `uv.lock`, ensuring that the necessary packages are available for development and execution. 
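As a minimal standalone sketch of the pattern described above — schedule the async shutdown on the already-running loop instead of calling `asyncio.run()` from inside it — the snippet below captures the idea. The `impls` mapping and the per-implementation timeout are illustrative assumptions, not the server's exact code:

```python
# Sketch of the signal-handling pattern described above; `impls` and the
# timeouts are illustrative assumptions, not the server's implementation.
import asyncio
import signal


def handle_signal(impls: dict, signum, _frame) -> None:
    signame = signal.Signals(signum).name
    print(f"Received signal {signame} ({signum}). Exiting gracefully...")

    async def shutdown():
        # Gracefully shut down each implementation, then cancel remaining tasks.
        for name, impl in impls.items():
            if hasattr(impl, "shutdown"):
                try:
                    await asyncio.wait_for(impl.shutdown(), timeout=5)
                except Exception:
                    print(f"Failed to cleanly shut down {name}")
        loop = asyncio.get_running_loop()
        tasks = [t for t in asyncio.all_tasks(loop) if t is not asyncio.current_task()]
        for task in tasks:
            task.cancel()
        await asyncio.gather(*tasks, return_exceptions=True)
        loop.stop()

    # asyncio.run(shutdown()) would raise inside the running Uvicorn loop;
    # scheduling a task on the existing loop is what makes the shutdown graceful.
    asyncio.get_running_loop().create_task(shutdown())


# Registered for both SIGINT and SIGTERM, e.g.:
#   signal.signal(signal.SIGINT, functools.partial(handle_signal, impls))
#   signal.signal(signal.SIGTERM, functools.partial(handle_signal, impls))
```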
Closes: https://github.com/meta-llama/llama-stack/issues/1043 Signed-off-by: Sébastien Han [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan Run a server and send SIGINT: ``` INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct" python -m llama_stack.distribution.server.server --yaml-config ./llama_stack/templates/ollama/run.yaml Using config file: llama_stack/templates/ollama/run.yaml Run configuration: apis: - agents - datasetio - eval - inference - safety - scoring - telemetry - tool_runtime - vector_io container_image: null datasets: [] eval_tasks: [] image_name: ollama metadata_store: db_path: /Users/leseb/.llama/distributions/ollama/registry.db namespace: null type: sqlite models: - metadata: {} model_id: meta-llama/Llama-3.2-3B-Instruct model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType - llm provider_id: ollama provider_model_id: null - metadata: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 model_type: !!python/object/apply:llama_stack.apis.models.models.ModelType - embedding provider_id: sentence-transformers provider_model_id: null providers: agents: - config: persistence_store: db_path: /Users/leseb/.llama/distributions/ollama/agents_store.db namespace: null type: sqlite provider_id: meta-reference provider_type: inline::meta-reference datasetio: - config: {} provider_id: huggingface provider_type: remote::huggingface - config: {} provider_id: localfs provider_type: inline::localfs eval: - config: {} provider_id: meta-reference provider_type: inline::meta-reference inference: - config: url: http://localhost:11434 provider_id: ollama provider_type: remote::ollama - config: {} provider_id: sentence-transformers provider_type: inline::sentence-transformers safety: - config: {} provider_id: llama-guard provider_type: inline::llama-guard scoring: - config: {} provider_id: basic provider_type: inline::basic - config: {} provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: openai_api_key: '********' provider_id: braintrust provider_type: inline::braintrust telemetry: - config: service_name: llama-stack sinks: console,sqlite sqlite_db_path: /Users/leseb/.llama/distributions/ollama/trace_store.db provider_id: meta-reference provider_type: inline::meta-reference tool_runtime: - config: api_key: '********' max_results: 3 provider_id: brave-search provider_type: remote::brave-search - config: api_key: '********' max_results: 3 provider_id: tavily-search provider_type: remote::tavily-search - config: {} provider_id: code-interpreter provider_type: inline::code-interpreter - config: {} provider_id: rag-runtime provider_type: inline::rag-runtime vector_io: - config: kvstore: db_path: /Users/leseb/.llama/distributions/ollama/faiss_store.db namespace: null type: sqlite provider_id: faiss provider_type: inline::faiss scoring_fns: [] server: port: 8321 tls_certfile: null tls_keyfile: null shields: [] tool_groups: - args: null mcp_endpoint: null provider_id: tavily-search toolgroup_id: builtin::websearch - args: null mcp_endpoint: null provider_id: rag-runtime toolgroup_id: builtin::rag - args: null mcp_endpoint: null provider_id: code-interpreter toolgroup_id: builtin::code_interpreter vector_dbs: [] version: '2' INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:213: Resolved 31 providers INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-inference => ollama INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-inference => 
sentence-transformers INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: models => __routing_table__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inference => __autorouted__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-vector_io => faiss INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-safety => llama-guard INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: shields => __routing_table__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: safety => __autorouted__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: vector_dbs => __routing_table__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: vector_io => __autorouted__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-tool_runtime => brave-search INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-tool_runtime => tavily-search INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-tool_runtime => code-interpreter INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-tool_runtime => rag-runtime INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: tool_groups => __routing_table__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: tool_runtime => __autorouted__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: agents => meta-reference INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-datasetio => huggingface INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-datasetio => localfs INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: datasets => __routing_table__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: datasetio => __autorouted__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: telemetry => meta-reference INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-scoring => basic INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-scoring => llm-as-judge INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-scoring => braintrust INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: scoring_functions => __routing_table__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: scoring => __autorouted__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inner-eval => meta-reference INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: eval_tasks => __routing_table__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: eval => __autorouted__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:215: inspect => __builtin__ INFO 2025-02-12 10:21:03,540 llama_stack.distribution.resolver:216: INFO 2025-02-12 10:21:03,723 llama_stack.providers.remote.inference.ollama.ollama:148: checking connectivity to Ollama at `http://localhost:11434`... INFO 2025-02-12 10:21:03,734 httpx:1740: HTTP Request: GET http://localhost:11434/api/ps "HTTP/1.1 200 OK" INFO 2025-02-12 10:21:03,843 faiss.loader:148: Loading faiss. INFO 2025-02-12 10:21:03,865 faiss.loader:150: Successfully loaded faiss. INFO 2025-02-12 10:21:03,868 faiss:173: Failed to load GPU Faiss: name 'GpuIndexIVFFlat' is not defined. Will not load constructor refs for GPU indexes. Warning: `bwrap` is not available. Code interpreter tool will not work correctly. 
INFO 2025-02-12 10:21:04,315 datasets:54: PyTorch version 2.6.0 available. INFO 2025-02-12 10:21:04,556 httpx:1740: HTTP Request: GET http://localhost:11434/api/ps "HTTP/1.1 200 OK" INFO 2025-02-12 10:21:04,557 llama_stack.providers.utils.inference.embedding_mixin:42: Loading sentence transformer for all-MiniLM-L6-v2... INFO 2025-02-12 10:21:07,202 sentence_transformers.SentenceTransformer:210: Use pytorch device_name: mps INFO 2025-02-12 10:21:07,202 sentence_transformers.SentenceTransformer:218: Load pretrained SentenceTransformer: all-MiniLM-L6-v2 INFO 2025-02-12 10:21:09,500 llama_stack.distribution.stack:102: Models: all-MiniLM-L6-v2 served by sentence-transformers INFO 2025-02-12 10:21:09,500 llama_stack.distribution.stack:102: Models: meta-llama/Llama-3.2-3B-Instruct served by ollama INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: basic::equality served by basic INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: basic::regex_parser_multiple_choice_answer served by basic INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: basic::subset_of served by basic INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::answer-correctness served by braintrust INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::answer-relevancy served by braintrust INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::answer-similarity served by braintrust INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::context-entity-recall served by braintrust INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::context-precision served by braintrust INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::context-recall served by braintrust INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::context-relevancy served by braintrust INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::factuality served by braintrust INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: braintrust::faithfulness served by braintrust INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: llm-as-judge::405b-simpleqa served by llm-as-judge INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Scoring_fns: llm-as-judge::base served by llm-as-judge INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Tool_groups: builtin::code_interpreter served by code-interpreter INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Tool_groups: builtin::rag served by rag-runtime INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:102: Tool_groups: builtin::websearch served by tavily-search INFO 2025-02-12 10:21:09,501 llama_stack.distribution.stack:106: Serving API eval POST /v1/eval/tasks/{task_id}/evaluations DELETE /v1/eval/tasks/{task_id}/jobs/{job_id} GET /v1/eval/tasks/{task_id}/jobs/{job_id}/result GET /v1/eval/tasks/{task_id}/jobs/{job_id} POST /v1/eval/tasks/{task_id}/jobs Serving API agents POST /v1/agents POST /v1/agents/{agent_id}/session POST /v1/agents/{agent_id}/session/{session_id}/turn DELETE /v1/agents/{agent_id} DELETE /v1/agents/{agent_id}/session/{session_id} GET /v1/agents/{agent_id}/session/{session_id} GET /v1/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id} GET 
/v1/agents/{agent_id}/session/{session_id}/turn/{turn_id} Serving API scoring_functions GET /v1/scoring-functions/{scoring_fn_id} GET /v1/scoring-functions POST /v1/scoring-functions Serving API safety POST /v1/safety/run-shield Serving API inspect GET /v1/health GET /v1/inspect/providers GET /v1/inspect/routes GET /v1/version Serving API tool_runtime POST /v1/tool-runtime/invoke GET /v1/tool-runtime/list-tools POST /v1/tool-runtime/rag-tool/insert POST /v1/tool-runtime/rag-tool/query Serving API datasetio POST /v1/datasetio/rows GET /v1/datasetio/rows Serving API shields GET /v1/shields/{identifier} GET /v1/shields POST /v1/shields Serving API eval_tasks GET /v1/eval-tasks/{eval_task_id} GET /v1/eval-tasks POST /v1/eval-tasks Serving API models GET /v1/models/{model_id} GET /v1/models POST /v1/models DELETE /v1/models/{model_id} Serving API datasets GET /v1/datasets/{dataset_id} GET /v1/datasets POST /v1/datasets DELETE /v1/datasets/{dataset_id} Serving API vector_io POST /v1/vector-io/insert POST /v1/vector-io/query Serving API inference POST /v1/inference/chat-completion POST /v1/inference/completion POST /v1/inference/embeddings Serving API tool_groups GET /v1/tools/{tool_name} GET /v1/toolgroups/{toolgroup_id} GET /v1/toolgroups GET /v1/tools POST /v1/toolgroups DELETE /v1/toolgroups/{toolgroup_id} Serving API vector_dbs GET /v1/vector-dbs/{vector_db_id} GET /v1/vector-dbs POST /v1/vector-dbs DELETE /v1/vector-dbs/{vector_db_id} Serving API scoring POST /v1/scoring/score POST /v1/scoring/score-batch Serving API telemetry GET /v1/telemetry/traces/{trace_id}/spans/{span_id} GET /v1/telemetry/spans/{span_id}/tree GET /v1/telemetry/traces/{trace_id} POST /v1/telemetry/events GET /v1/telemetry/spans GET /v1/telemetry/traces POST /v1/telemetry/spans/export Listening on ['::', '0.0.0.0']:5001 INFO: Started server process [65372] INFO: Waiting for application startup. INFO: ASGI 'lifespan' protocol appears unsupported. INFO: Application startup complete. INFO: Uvicorn running on http://['::', '0.0.0.0']:5001 (Press CTRL+C to quit) ^CINFO: Shutting down INFO: Finished server process [65372] Received signal SIGINT (2). Exiting gracefully... 
INFO 2025-02-12 10:21:11,215 __main__:151: Shutting down ModelsRoutingTable INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down InferenceRouter INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down ShieldsRoutingTable INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down SafetyRouter INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down VectorDBsRoutingTable INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down VectorIORouter INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down ToolGroupsRoutingTable INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down ToolRuntimeRouter INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down MetaReferenceAgentsImpl INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down DatasetsRoutingTable INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down DatasetIORouter INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down TelemetryAdapter INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down ScoringFunctionsRoutingTable INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down ScoringRouter INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down EvalTasksRoutingTable INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down EvalRouter INFO 2025-02-12 10:21:11,216 __main__:151: Shutting down DistributionInspectImpl ``` [//]: # (## Documentation) [//]: # (- [ ] Added a Changelog entry if the change is significant) Signed-off-by: Sébastien Han --- llama_stack/distribution/inspect.py | 3 + .../distribution/routers/routing_tables.py | 3 + llama_stack/distribution/server/server.py | 77 ++++++++++++++++--- .../inline/agents/meta_reference/agents.py | 3 + pyproject.toml | 2 + uv.lock | 18 +++++ 6 files changed, 94 insertions(+), 12 deletions(-) diff --git a/llama_stack/distribution/inspect.py b/llama_stack/distribution/inspect.py index b7ee4a219..fddb62570 100644 --- a/llama_stack/distribution/inspect.py +++ b/llama_stack/distribution/inspect.py @@ -82,3 +82,6 @@ class DistributionInspectImpl(Inspect): async def version(self) -> VersionInfo: return VersionInfo(version=version("llama-stack")) + + async def shutdown(self) -> None: + pass diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 68fafd8ee..009775ca5 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -537,3 +537,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups): for tool in tools: await self.unregister_object(tool) await self.unregister_object(tool_group) + + async def shutdown(self) -> None: + pass diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index d2c32de11..bb735268b 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -7,6 +7,7 @@ import argparse import asyncio import functools +import logging import inspect import json import os @@ -52,6 +53,9 @@ from .endpoints import get_all_api_endpoints REPO_ROOT = Path(__file__).parent.parent.parent.parent +logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s") +logger = logging.getLogger(__name__) + def warn_with_traceback(message, category, filename, lineno, file=None, line=None): log = file if hasattr(file, "write") else sys.stderr @@ -112,21 +116,69 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio ) -def handle_sigint(app, *args, **kwargs): - print("SIGINT or CTRL-C detected. 
Exiting gracefully...") +def handle_signal(app, signum, _) -> None: + """ + Handle incoming signals and initiate a graceful shutdown of the application. - async def run_shutdown(): - for impl in app.__llama_stack_impls__.values(): - print(f"Shutting down {impl}") - await impl.shutdown() + This function is intended to be used as a signal handler for various signals + (e.g., SIGINT, SIGTERM). Upon receiving a signal, it will print a message + indicating the received signal and initiate a shutdown process. - asyncio.run(run_shutdown()) + Args: + app: The application instance containing implementations to be shut down. + signum (int): The signal number received. + frame: The current stack frame (not used in this function). - loop = asyncio.get_event_loop() - for task in asyncio.all_tasks(loop): - task.cancel() + The shutdown process involves: + - Shutting down all implementations registered in the application. + - Gathering all running asyncio tasks. + - Cancelling all gathered tasks. + - Waiting for all tasks to finish. + - Stopping the event loop. - loop.stop() + Note: + This function schedules the shutdown process as an asyncio task and does + not block the current execution. + """ + signame = signal.Signals(signum).name + print(f"Received signal {signame} ({signum}). Exiting gracefully...") + + async def shutdown(): + try: + # Gracefully shut down implementations + for impl in app.__llama_stack_impls__.values(): + impl_name = impl.__class__.__name__ + logger.info("Shutting down %s", impl_name) + try: + if hasattr(impl, "shutdown"): + await asyncio.wait_for(impl.shutdown(), timeout=5) + else: + logger.warning("No shutdown method for %s", impl_name) + except asyncio.TimeoutError: + logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True) + except Exception as e: + logger.exception("Failed to shutdown %s: %s", impl_name, {e}) + + # Gather all running tasks + loop = asyncio.get_running_loop() + tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()] + + # Cancel all tasks + for task in tasks: + task.cancel() + + # Wait for all tasks to finish + try: + await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10) + except asyncio.TimeoutError: + logger.exception("Timeout while waiting for tasks to finish") + except asyncio.CancelledError: + pass + finally: + loop.stop() + + loop = asyncio.get_running_loop() + loop.create_task(shutdown()) @asynccontextmanager @@ -386,7 +438,8 @@ def main(): print("") app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler) - signal.signal(signal.SIGINT, functools.partial(handle_sigint, app)) + signal.signal(signal.SIGINT, functools.partial(handle_signal, app)) + signal.signal(signal.SIGTERM, functools.partial(handle_signal, app)) app.__llama_stack_impls__ = impls diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index fe4ccd1a3..e3c18d112 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -212,3 +212,6 @@ class MetaReferenceAgentsImpl(Agents): async def delete_agent(self, agent_id: str) -> None: await self.persistence_store.delete(f"agent:{agent_id}") + + async def shutdown(self) -> None: + pass diff --git a/pyproject.toml b/pyproject.toml index 5e9cb75e2..2f40ceac9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,8 @@ dev = [ 
"types-requests", "types-setuptools", "pre-commit", + "uvicorn", + "fastapi", ] docs = [ "sphinx-autobuild", diff --git a/uv.lock b/uv.lock index 087396eea..97ae52124 100644 --- a/uv.lock +++ b/uv.lock @@ -431,6 +431,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702 }, ] +[[package]] +name = "fastapi" +version = "0.115.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, + { name = "starlette" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a2/b2/5a5dc4affdb6661dea100324e19a7721d5dc524b464fe8e366c093fd7d87/fastapi-0.115.8.tar.gz", hash = "sha256:0ce9111231720190473e222cdf0f07f7206ad7e53ea02beb1d2dc36e2f0741e9", size = 295403 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8f/7d/2d6ce181d7a5f51dedb8c06206cbf0ec026a99bf145edd309f9e17c3282f/fastapi-0.115.8-py3-none-any.whl", hash = "sha256:753a96dd7e036b34eeef8babdfcfe3f28ff79648f86551eb36bfc1b0bf4a8cbf", size = 94814 }, +] + [[package]] name = "fastjsonschema" version = "2.21.1" @@ -724,6 +738,7 @@ dependencies = [ [package.optional-dependencies] dev = [ { name = "black" }, + { name = "fastapi" }, { name = "nbval" }, { name = "pre-commit" }, { name = "pytest" }, @@ -731,6 +746,7 @@ dev = [ { name = "ruff" }, { name = "types-requests" }, { name = "types-setuptools" }, + { name = "uvicorn" }, ] docs = [ { name = "myst-parser" }, @@ -748,6 +764,7 @@ docs = [ requires-dist = [ { name = "black", marker = "extra == 'dev'" }, { name = "blobfile" }, + { name = "fastapi", marker = "extra == 'dev'" }, { name = "fire" }, { name = "httpx" }, { name = "huggingface-hub" }, @@ -776,6 +793,7 @@ requires-dist = [ { name = "termcolor" }, { name = "types-requests", marker = "extra == 'dev'" }, { name = "types-setuptools", marker = "extra == 'dev'" }, + { name = "uvicorn", marker = "extra == 'dev'" }, ] [[package]] From 47fccf0d03814589230af1a32fd2f683d4f74819 Mon Sep 17 00:00:00 2001 From: Reid <61492567+reidliu41@users.noreply.github.com> Date: Fri, 14 Feb 2025 00:33:11 +0800 Subject: [PATCH 21/27] style: update model id in model list title (#1072) # What does this PR do? [Provide a short summary of what this PR does and why. Link to relevant issues if applicable.] Since the subcommands used `MODEL_ID`, it would be better to use it in `model list` and make it easy to find it. 
``` $ llama model verify-download --help usage: llama model verify-download [-h] --model-id MODEL_ID << $ llama model describe --help usage: llama model describe [-h] -m MODEL_ID << $ llama download --help --model-id MODEL_ID See `llama model list` or `llama model list --show-all` for the list of available models before: $ llama model list +-----------------------------------------+-----------------------------------------------------+----------------+ | Model Descriptor | Hugging Face Repo | Context Length | +-----------------------------------------+-----------------------------------------------------+----------------+ after: $ llama model list +-----------------------------------------+-----------------------------------------------------+----------------+ | Model Descriptor | Model ID | Context Length | +-----------------------------------------+-----------------------------------------------------+----------------+ | Llama3.1-8B | meta-llama/Llama-3.1-8B | 128K | +-----------------------------------------+-----------------------------------------------------+----------------+ ``` [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation) Signed-off-by: reidliu Co-authored-by: reidliu --- llama_stack/cli/model/list.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/cli/model/list.py b/llama_stack/cli/model/list.py index 6d296e75e..9b5ebb1a5 100644 --- a/llama_stack/cli/model/list.py +++ b/llama_stack/cli/model/list.py @@ -38,7 +38,7 @@ class ModelList(Subcommand): headers = [ "Model Descriptor", - "Hugging Face Repo", + "Model ID", "Context Length", ] From 2fa9e3c941d4b1d3183f45d3fa883637b1aa4110 Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 13 Feb 2025 08:46:43 -0800 Subject: [PATCH 22/27] fix: make backslash work in GET /models/{model_id:path} (#1068) --- docs/openapi_generator/pyopenapi/generator.py | 5 ++++- .../openapi_generator/pyopenapi/operations.py | 2 ++ llama_stack/apis/agents/agents.py | 19 +++++++++++-------- llama_stack/apis/datasets/datasets.py | 4 ++-- llama_stack/apis/models/models.py | 4 ++-- .../scoring_functions/scoring_functions.py | 2 +- llama_stack/apis/shields/shields.py | 2 +- llama_stack/apis/telemetry/telemetry.py | 8 ++++---- llama_stack/apis/tools/tools.py | 6 +++--- llama_stack/apis/vector_dbs/vector_dbs.py | 4 ++-- 10 files changed, 32 insertions(+), 24 deletions(-) diff --git a/docs/openapi_generator/pyopenapi/generator.py b/docs/openapi_generator/pyopenapi/generator.py index f0d30a0e6..a0385cae0 100644 --- a/docs/openapi_generator/pyopenapi/generator.py +++ b/docs/openapi_generator/pyopenapi/generator.py @@ -644,7 +644,9 @@ class Generator: else: callbacks = None - description = "\n".join(filter(None, [doc_string.short_description, doc_string.long_description])) + description = "\n".join( + filter(None, [doc_string.short_description, doc_string.long_description]) + ) return Operation( tags=[op.defining_class.__name__], summary=None, @@ -681,6 +683,7 @@ class Generator: raise NotImplementedError(f"unknown HTTP method: {op.http_method}") route = op.get_route() + route = route.replace(":path", "") print(f"route: {route}") if route in paths: paths[route].update(pathItem) diff --git a/docs/openapi_generator/pyopenapi/operations.py b/docs/openapi_generator/pyopenapi/operations.py index 
abeb16936..bf4d35c87 100644 --- a/docs/openapi_generator/pyopenapi/operations.py +++ b/docs/openapi_generator/pyopenapi/operations.py @@ -130,6 +130,8 @@ class _FormatParameterExtractor: def _get_route_parameters(route: str) -> List[str]: extractor = _FormatParameterExtractor() + # Replace all occurrences of ":path" with empty string + route = route.replace(":path", "") route.format_map(extractor) return extractor.keys diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 785248633..b20145be9 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -29,11 +29,11 @@ from llama_stack.apis.inference import ( SamplingParams, ToolCall, ToolChoice, + ToolConfig, ToolPromptFormat, ToolResponse, ToolResponseMessage, UserMessage, - ToolConfig, ) from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.tools import ToolDef @@ -318,7 +318,7 @@ class Agents(Protocol): agent_config: AgentConfig, ) -> AgentCreateResponse: ... - @webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST") + @webmethod(route="/agents/{agent_id:path}/session/{session_id:path}/turn", method="POST") async def create_agent_turn( self, agent_id: str, @@ -335,7 +335,10 @@ class Agents(Protocol): tool_config: Optional[ToolConfig] = None, ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... - @webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET") + @webmethod( + route="/agents/{agent_id:path}/session/{session_id:path}/turn/{turn_id:path}", + method="GET", + ) async def get_agents_turn( self, agent_id: str, @@ -344,7 +347,7 @@ class Agents(Protocol): ) -> Turn: ... @webmethod( - route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}", + route="/agents/{agent_id:path}/session/{session_id:path}/turn/{turn_id:path}/step/{step_id:path}", method="GET", ) async def get_agents_step( @@ -355,14 +358,14 @@ class Agents(Protocol): step_id: str, ) -> AgentStepResponse: ... - @webmethod(route="/agents/{agent_id}/session", method="POST") + @webmethod(route="/agents/{agent_id:path}/session", method="POST") async def create_agent_session( self, agent_id: str, session_name: str, ) -> AgentSessionCreateResponse: ... - @webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET") + @webmethod(route="/agents/{agent_id:path}/session/{session_id:path}", method="GET") async def get_agents_session( self, session_id: str, @@ -370,14 +373,14 @@ class Agents(Protocol): turn_ids: Optional[List[str]] = None, ) -> Session: ... - @webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE") + @webmethod(route="/agents/{agent_id:path}/session/{session_id:path}", method="DELETE") async def delete_agents_session( self, session_id: str, agent_id: str, ) -> None: ... - @webmethod(route="/agents/{agent_id}", method="DELETE") + @webmethod(route="/agents/{agent_id:path}", method="DELETE") async def delete_agent( self, agent_id: str, diff --git a/llama_stack/apis/datasets/datasets.py b/llama_stack/apis/datasets/datasets.py index 5ad5bdcdb..5e2b38697 100644 --- a/llama_stack/apis/datasets/datasets.py +++ b/llama_stack/apis/datasets/datasets.py @@ -58,7 +58,7 @@ class Datasets(Protocol): metadata: Optional[Dict[str, Any]] = None, ) -> None: ... 
- @webmethod(route="/datasets/{dataset_id}", method="GET") + @webmethod(route="/datasets/{dataset_id:path}", method="GET") async def get_dataset( self, dataset_id: str, @@ -67,7 +67,7 @@ class Datasets(Protocol): @webmethod(route="/datasets", method="GET") async def list_datasets(self) -> ListDatasetsResponse: ... - @webmethod(route="/datasets/{dataset_id}", method="DELETE") + @webmethod(route="/datasets/{dataset_id:path}", method="DELETE") async def unregister_dataset( self, dataset_id: str, diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 3361c2836..7e6d9854f 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -62,7 +62,7 @@ class Models(Protocol): @webmethod(route="/models", method="GET") async def list_models(self) -> ListModelsResponse: ... - @webmethod(route="/models/{model_id}", method="GET") + @webmethod(route="/models/{model_id:path}", method="GET") async def get_model( self, model_id: str, @@ -78,7 +78,7 @@ class Models(Protocol): model_type: Optional[ModelType] = None, ) -> Model: ... - @webmethod(route="/models/{model_id}", method="DELETE") + @webmethod(route="/models/{model_id:path}", method="DELETE") async def unregister_model( self, model_id: str, diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 325979583..3fa40ffbf 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -134,7 +134,7 @@ class ScoringFunctions(Protocol): @webmethod(route="/scoring-functions", method="GET") async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ... - @webmethod(route="/scoring-functions/{scoring_fn_id}", method="GET") + @webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET") async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ... @webmethod(route="/scoring-functions", method="POST") diff --git a/llama_stack/apis/shields/shields.py b/llama_stack/apis/shields/shields.py index 3dd685b14..ae316ee53 100644 --- a/llama_stack/apis/shields/shields.py +++ b/llama_stack/apis/shields/shields.py @@ -48,7 +48,7 @@ class Shields(Protocol): @webmethod(route="/shields", method="GET") async def list_shields(self) -> ListShieldsResponse: ... - @webmethod(route="/shields/{identifier}", method="GET") + @webmethod(route="/shields/{identifier:path}", method="GET") async def get_shield(self, identifier: str) -> Optional[Shield]: ... @webmethod(route="/shields", method="POST") diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index 6272cc40b..5622aaeac 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -13,8 +13,8 @@ from typing import ( Literal, Optional, Protocol, - Union, runtime_checkable, + Union, ) from llama_models.llama3.api.datatypes import Primitive @@ -224,13 +224,13 @@ class Telemetry(Protocol): order_by: Optional[List[str]] = None, ) -> QueryTracesResponse: ... - @webmethod(route="/telemetry/traces/{trace_id}", method="GET") + @webmethod(route="/telemetry/traces/{trace_id:path}", method="GET") async def get_trace(self, trace_id: str) -> Trace: ... - @webmethod(route="/telemetry/traces/{trace_id}/spans/{span_id}", method="GET") + @webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET") async def get_span(self, trace_id: str, span_id: str) -> Span: ... 
- @webmethod(route="/telemetry/spans/{span_id}/tree", method="GET") + @webmethod(route="/telemetry/spans/{span_id:path}/tree", method="GET") async def get_span_tree( self, span_id: str, diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index d6d806c53..a8e946b08 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -101,7 +101,7 @@ class ToolGroups(Protocol): """Register a tool group""" ... - @webmethod(route="/toolgroups/{toolgroup_id}", method="GET") + @webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET") async def get_tool_group( self, toolgroup_id: str, @@ -117,13 +117,13 @@ class ToolGroups(Protocol): """List tools with optional tool group""" ... - @webmethod(route="/tools/{tool_name}", method="GET") + @webmethod(route="/tools/{tool_name:path}", method="GET") async def get_tool( self, tool_name: str, ) -> Tool: ... - @webmethod(route="/toolgroups/{toolgroup_id}", method="DELETE") + @webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE") async def unregister_toolgroup( self, toolgroup_id: str, diff --git a/llama_stack/apis/vector_dbs/vector_dbs.py b/llama_stack/apis/vector_dbs/vector_dbs.py index 4b782e2d5..1da2c128c 100644 --- a/llama_stack/apis/vector_dbs/vector_dbs.py +++ b/llama_stack/apis/vector_dbs/vector_dbs.py @@ -46,7 +46,7 @@ class VectorDBs(Protocol): @webmethod(route="/vector-dbs", method="GET") async def list_vector_dbs(self) -> ListVectorDBsResponse: ... - @webmethod(route="/vector-dbs/{vector_db_id}", method="GET") + @webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET") async def get_vector_db( self, vector_db_id: str, @@ -62,5 +62,5 @@ class VectorDBs(Protocol): provider_vector_db_id: Optional[str] = None, ) -> VectorDB: ... - @webmethod(route="/vector-dbs/{vector_db_id}", method="DELETE") + @webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE") async def unregister_vector_db(self, vector_db_id: str) -> None: ... From f9ca4419744ad60b400284ba2dfa74aeb0a13fa0 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Thu, 13 Feb 2025 12:14:57 -0500 Subject: [PATCH 23/27] chore: Link to Groq docs in the warning message for preview model (#1060) This should be `llama-3.2-3b` instead of `llama-3.2-3b-instruct`. --- llama_stack/providers/remote/inference/groq/groq.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py index 4e6cc2d6b..9b3c1abbf 100644 --- a/llama_stack/providers/remote/inference/groq/groq.py +++ b/llama_stack/providers/remote/inference/groq/groq.py @@ -108,6 +108,7 @@ class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderD "Groq only contains a preview version for llama-3.2-3b-instruct. " "Preview models aren't recommended for production use. " "They can be discontinued on short notice." + "More details: https://console.groq.com/docs/models" ) request = convert_chat_completion_request( From 1527c301076d1e062e652b00c714a5700157e41e Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Thu, 13 Feb 2025 10:04:43 -0800 Subject: [PATCH 24/27] fix: remove :path in agents (#1077) # What does this PR do? Remove :path in agents, we cannot have :path in params inside endpoints except last one ## Test Plan [Describe the tests you ran to verify your changes with result summaries. 
*Provide clear instructions so the plan can be easily re-executed.*] ``` llama stack run ``` [//]: # (## Documentation) --- llama_stack/apis/agents/agents.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index b20145be9..e2901448b 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -318,7 +318,7 @@ class Agents(Protocol): agent_config: AgentConfig, ) -> AgentCreateResponse: ... - @webmethod(route="/agents/{agent_id:path}/session/{session_id:path}/turn", method="POST") + @webmethod(route="/agents/{agent_id}/session/{session_id}/turn", method="POST") async def create_agent_turn( self, agent_id: str, @@ -336,7 +336,7 @@ class Agents(Protocol): ) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ... @webmethod( - route="/agents/{agent_id:path}/session/{session_id:path}/turn/{turn_id:path}", + route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET", ) async def get_agents_turn( @@ -347,7 +347,7 @@ class Agents(Protocol): ) -> Turn: ... @webmethod( - route="/agents/{agent_id:path}/session/{session_id:path}/turn/{turn_id:path}/step/{step_id:path}", + route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}", method="GET", ) async def get_agents_step( @@ -358,14 +358,14 @@ class Agents(Protocol): step_id: str, ) -> AgentStepResponse: ... - @webmethod(route="/agents/{agent_id:path}/session", method="POST") + @webmethod(route="/agents/{agent_id}/session", method="POST") async def create_agent_session( self, agent_id: str, session_name: str, ) -> AgentSessionCreateResponse: ... - @webmethod(route="/agents/{agent_id:path}/session/{session_id:path}", method="GET") + @webmethod(route="/agents/{agent_id}/session/{session_id}", method="GET") async def get_agents_session( self, session_id: str, @@ -373,14 +373,14 @@ class Agents(Protocol): turn_ids: Optional[List[str]] = None, ) -> Session: ... - @webmethod(route="/agents/{agent_id:path}/session/{session_id:path}", method="DELETE") + @webmethod(route="/agents/{agent_id}/session/{session_id}", method="DELETE") async def delete_agents_session( self, session_id: str, agent_id: str, ) -> None: ... - @webmethod(route="/agents/{agent_id:path}", method="DELETE") + @webmethod(route="/agents/{agent_id}", method="DELETE") async def delete_agent( self, agent_id: str, From e4a1579e637622455c06d12cbf2f3df4e5b1dd84 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Thu, 13 Feb 2025 19:06:21 +0100 Subject: [PATCH 25/27] build: format codebase imports using ruff linter (#1028) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? - Configured ruff linter to automatically fix import sorting issues. - Set --exit-non-zero-on-fix to ensure non-zero exit code when fixes are applied. - Enabled the 'I' selection to focus on import-related linting rules. - Ran the linter, and formatted all codebase imports accordingly. - Removed the black dep from the "dev" group since we use ruff Signed-off-by: Sébastien Han [//]: # (If resolving an issue, uncomment and update the line below) [//]: # (Closes #[issue-number]) ## Test Plan [Describe the tests you ran to verify your changes with result summaries. 
*Provide clear instructions so the plan can be easily re-executed.*] [//]: # (## Documentation) [//]: # (- [ ] Added a Changelog entry if the change is significant) Signed-off-by: Sébastien Han --- .pre-commit-config.yaml | 4 +++- llama_stack/apis/agents/agents.py | 4 ++-- llama_stack/apis/agents/event_logger.py | 1 - llama_stack/apis/common/content_types.py | 1 - llama_stack/apis/common/deployment_types.py | 1 - .../scoring_functions/scoring_functions.py | 2 +- .../synthetic_data_generation.py | 2 -- llama_stack/apis/tools/__init__.py | 2 +- llama_stack/apis/tools/rag_tool.py | 2 +- llama_stack/apis/tools/tools.py | 2 +- llama_stack/cli/download.py | 2 -- llama_stack/cli/model/describe.py | 1 - llama_stack/cli/model/model.py | 1 - llama_stack/cli/model/prompt_format.py | 2 +- llama_stack/cli/model/safety_models.py | 1 - llama_stack/cli/stack/_build.py | 5 ++-- llama_stack/cli/tests/test_stack_config.py | 1 + llama_stack/distribution/build.py | 4 ---- llama_stack/distribution/client.py | 4 +--- llama_stack/distribution/configure.py | 4 +--- llama_stack/distribution/library_client.py | 24 +++++++++---------- llama_stack/distribution/routers/__init__.py | 1 - llama_stack/distribution/routers/routers.py | 2 +- llama_stack/distribution/server/endpoints.py | 3 --- llama_stack/distribution/server/server.py | 5 ++-- .../distribution/store/tests/test_registry.py | 2 +- llama_stack/distribution/ui/modules/api.py | 1 - .../ui/page/distribution/resources.py | 1 - .../ui/page/evaluations/app_eval.py | 1 - .../ui/page/evaluations/native_eval.py | 2 -- .../distribution/ui/page/playground/rag.py | 1 - llama_stack/distribution/utils/config_dirs.py | 1 - .../distribution/utils/prompt_for_config.py | 4 +--- llama_stack/providers/datatypes.py | 1 - .../agents/meta_reference/agent_instance.py | 2 +- .../inline/agents/meta_reference/safety.py | 2 -- .../meta_reference/tests/test_chat_agent.py | 1 - .../inline/datasetio/localfs/datasetio.py | 2 -- .../inline/eval/meta_reference/eval.py | 2 -- .../inline/inference/meta_reference/config.py | 1 - .../inference/meta_reference/generation.py | 2 -- .../inference/meta_reference/inference.py | 2 +- .../meta_reference/parallel_utils.py | 5 +--- .../meta_reference/quantization/fp8_impls.py | 3 +-- .../quantization/fp8_txest_disabled.py | 6 ++--- .../meta_reference/quantization/loader.py | 7 +----- .../scripts/quantize_checkpoint.py | 2 -- .../sentence_transformers.py | 2 +- .../providers/inline/inference/vllm/vllm.py | 2 +- .../post_training/torchtune/common/utils.py | 2 -- .../post_training/torchtune/datasets/sft.py | 1 - .../recipes/lora_finetuning_single_device.py | 7 ++---- .../safety/code_scanner/code_scanner.py | 1 - .../inline/safety/llama_guard/llama_guard.py | 4 ---- .../safety/prompt_guard/prompt_guard.py | 2 -- .../providers/inline/scoring/basic/scoring.py | 2 +- .../basic/scoring_fn/equality_scoring_fn.py | 1 - .../basic/scoring_fn/fn_defs/equality.py | 1 - .../basic/scoring_fn/fn_defs/subset_of.py | 1 - .../scoring_fn/regex_parser_scoring_fn.py | 1 - .../inline/scoring/braintrust/braintrust.py | 4 +--- .../scoring_fn/fn_defs/answer_correctness.py | 1 - .../scoring_fn/fn_defs/factuality.py | 1 - .../inline/scoring/llm_as_judge/scoring.py | 2 -- .../scoring_fn/fn_defs/llm_as_judge_base.py | 1 - .../scoring_fn/llm_as_judge_scoring_fn.py | 4 ---- .../inline/telemetry/sample/sample.py | 1 + .../code_interpreter/code_env_prefix.py | 4 ++++ .../tool_runtime/rag/context_retriever.py | 1 - .../inline/tool_runtime/rag/memory.py | 2 +- 
.../inline/vector_io/faiss/__init__.py | 1 + .../providers/inline/vector_io/faiss/faiss.py | 2 -- .../inline/vector_io/sqlite_vec/__init__.py | 2 ++ .../inline/vector_io/sqlite_vec/config.py | 3 ++- .../inline/vector_io/sqlite_vec/sqlite_vec.py | 9 +++---- .../providers/remote/agents/sample/sample.py | 1 + .../datasetio/huggingface/huggingface.py | 1 - .../remote/inference/bedrock/bedrock.py | 4 ++-- .../remote/inference/cerebras/cerebras.py | 2 +- .../remote/inference/databricks/databricks.py | 2 +- .../remote/inference/fireworks/fireworks.py | 2 +- .../providers/remote/inference/groq/groq.py | 2 +- .../remote/inference/groq/groq_utils.py | 6 ++--- .../remote/inference/nvidia/nvidia.py | 2 +- .../remote/inference/nvidia/openai_utils.py | 19 ++++++++++++++- .../remote/inference/ollama/config.py | 1 - .../remote/inference/ollama/ollama.py | 4 ++-- .../remote/inference/runpod/runpod.py | 2 -- .../remote/inference/sambanova/sambanova.py | 2 +- .../remote/inference/sample/sample.py | 1 + .../providers/remote/inference/tgi/tgi.py | 4 ++-- .../remote/inference/together/together.py | 2 +- .../providers/remote/inference/vllm/vllm.py | 20 ++++++++-------- .../remote/safety/bedrock/bedrock.py | 3 --- .../providers/remote/safety/sample/sample.py | 1 + .../model_context_protocol/__init__.py | 1 - .../remote/vector_io/chroma/chroma.py | 1 + .../remote/vector_io/pgvector/pgvector.py | 4 +--- .../remote/vector_io/qdrant/qdrant.py | 1 + .../remote/vector_io/sample/sample.py | 1 + .../remote/vector_io/weaviate/weaviate.py | 1 - .../providers/tests/agents/conftest.py | 1 - .../providers/tests/agents/test_agents.py | 1 - .../tests/agents/test_persistence.py | 1 - llama_stack/providers/tests/conftest.py | 2 -- .../providers/tests/datasetio/fixtures.py | 1 - llama_stack/providers/tests/eval/conftest.py | 4 +--- llama_stack/providers/tests/eval/fixtures.py | 2 +- llama_stack/providers/tests/eval/test_eval.py | 2 +- .../providers/tests/inference/fixtures.py | 2 -- .../tests/inference/groq/test_groq_utils.py | 5 +++- .../tests/inference/groq/test_init.py | 2 +- .../inference/test_model_registration.py | 1 - .../tests/inference/test_text_inference.py | 3 --- .../providers/tests/post_training/conftest.py | 2 -- .../providers/tests/post_training/fixtures.py | 3 --- llama_stack/providers/tests/report.py | 2 -- .../providers/tests/safety/conftest.py | 2 -- .../providers/tests/safety/fixtures.py | 3 --- .../providers/tests/scoring/conftest.py | 1 - .../providers/tests/scoring/fixtures.py | 2 +- .../providers/tests/vector_io/conftest.py | 2 -- .../providers/tests/vector_io/fixtures.py | 1 - .../tests/vector_io/test_vector_io.py | 2 -- .../tests/vector_io/test_vector_store.py | 3 +-- .../utils/common/data_schema_validator.py | 1 - .../providers/utils/datasetio/url_utils.py | 1 - .../utils/inference/model_registry.py | 1 - .../utils/inference/openai_compat.py | 3 --- .../utils/inference/prompt_adapter.py | 4 ++-- .../providers/utils/kvstore/sqlite/sqlite.py | 1 - .../providers/utils/memory/vector_store.py | 8 +++---- llama_stack/scripts/distro_codegen.py | 3 +-- llama_stack/templates/dell/dell.py | 1 - llama_stack/templates/sambanova/sambanova.py | 1 - llama_stack/templates/template.py | 2 +- tests/client-sdk/agents/test_agents.py | 6 +++-- tests/client-sdk/conftest.py | 4 ++-- tests/client-sdk/report.py | 8 +++---- .../client-sdk/tool_runtime/test_rag_tool.py | 1 - 140 files changed, 139 insertions(+), 243 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bca91081f..a7ece3b25 100644 --- 
a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -29,10 +29,12 @@ repos: - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.9.4 hooks: + # Run the linter with import sorting. - id: ruff args: [ --fix, - --exit-non-zero-on-fix + --exit-non-zero-on-fix, + --select, I, ] - id: ruff-format diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index e2901448b..106d34584 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -15,14 +15,14 @@ from typing import ( Literal, Optional, Protocol, - runtime_checkable, Union, + runtime_checkable, ) from llama_models.schema_utils import json_schema_type, register_schema, webmethod from pydantic import BaseModel, ConfigDict, Field -from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL +from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent from llama_stack.apis.inference import ( CompletionMessage, ResponseFormat, diff --git a/llama_stack/apis/agents/event_logger.py b/llama_stack/apis/agents/event_logger.py index 021cb6e1a..835ce4cee 100644 --- a/llama_stack/apis/agents/event_logger.py +++ b/llama_stack/apis/agents/event_logger.py @@ -13,7 +13,6 @@ from termcolor import cprint from llama_stack.apis.agents import AgentTurnResponseEventType, StepType from llama_stack.apis.common.content_types import ToolCallParseStatus from llama_stack.apis.inference import ToolResponseMessage - from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) diff --git a/llama_stack/apis/common/content_types.py b/llama_stack/apis/common/content_types.py index 8e56f59b1..e648f9a19 100644 --- a/llama_stack/apis/common/content_types.py +++ b/llama_stack/apis/common/content_types.py @@ -8,7 +8,6 @@ from enum import Enum from typing import Annotated, List, Literal, Optional, Union from llama_models.llama3.api.datatypes import ToolCall - from llama_models.schema_utils import json_schema_type, register_schema from pydantic import BaseModel, Field, model_validator diff --git a/llama_stack/apis/common/deployment_types.py b/llama_stack/apis/common/deployment_types.py index 24de0cc91..16a5c8ad6 100644 --- a/llama_stack/apis/common/deployment_types.py +++ b/llama_stack/apis/common/deployment_types.py @@ -8,7 +8,6 @@ from enum import Enum from typing import Any, Dict, Optional from llama_models.schema_utils import json_schema_type - from pydantic import BaseModel from llama_stack.apis.common.content_types import URL diff --git a/llama_stack/apis/scoring_functions/scoring_functions.py b/llama_stack/apis/scoring_functions/scoring_functions.py index 3fa40ffbf..fece50fbd 100644 --- a/llama_stack/apis/scoring_functions/scoring_functions.py +++ b/llama_stack/apis/scoring_functions/scoring_functions.py @@ -12,8 +12,8 @@ from typing import ( Literal, Optional, Protocol, - runtime_checkable, Union, + runtime_checkable, ) from llama_models.schema_utils import json_schema_type, register_schema, webmethod diff --git a/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py index 13b209912..a61fb0cf2 100644 --- a/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py +++ b/llama_stack/apis/synthetic_data_generation/synthetic_data_generation.py @@ -5,11 +5,9 @@ # the root directory of this source tree. 
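The reordered imports throughout the remaining hunks all follow the grouping that ruff's `I` (isort) rules enforce: standard library first, then third-party packages, then first-party `llama_stack` modules, then relative imports, with a blank line between groups, and names inside a `from` import ordered so that uppercase/class-style names precede lowercase ones (hence `Union, runtime_checkable` above). A small hand-written sketch of the resulting layout, not a file from this patch:

```
# Hand-written sketch (not a file from this patch) of the import layout ruff's
# "I" rules produce: stdlib, then third-party, then first-party, blank-line separated.
import json
from typing import Any, Dict

from pydantic import BaseModel

from llama_stack.apis.version import LLAMA_STACK_API_VERSION


class ApiInfo(BaseModel):
    version: str
    extra: Dict[str, Any] = {}


print(json.dumps(ApiInfo(version=LLAMA_STACK_API_VERSION).model_dump()))
```

With `--fix --exit-non-zero-on-fix --select I`, the hook rewrites files into this layout and still exits non-zero when it had to fix anything, so the corrected files get re-staged before the commit goes through.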
from enum import Enum - from typing import Any, Dict, List, Optional, Protocol, Union from llama_models.schema_utils import json_schema_type, webmethod - from pydantic import BaseModel from llama_stack.apis.inference import Message diff --git a/llama_stack/apis/tools/__init__.py b/llama_stack/apis/tools/__init__.py index 8cd798ebf..be8846ba2 100644 --- a/llama_stack/apis/tools/__init__.py +++ b/llama_stack/apis/tools/__init__.py @@ -4,5 +4,5 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from .tools import * # noqa: F401 F403 from .rag_tool import * # noqa: F401 F403 +from .tools import * # noqa: F401 F403 diff --git a/llama_stack/apis/tools/rag_tool.py b/llama_stack/apis/tools/rag_tool.py index 2e9bf9c51..2e6b43eb8 100644 --- a/llama_stack/apis/tools/rag_tool.py +++ b/llama_stack/apis/tools/rag_tool.py @@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, register_schema, webmeth from pydantic import BaseModel, Field from typing_extensions import Annotated, Protocol, runtime_checkable -from llama_stack.apis.common.content_types import InterleavedContent, URL +from llama_stack.apis.common.content_types import URL, InterleavedContent from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol diff --git a/llama_stack/apis/tools/tools.py b/llama_stack/apis/tools/tools.py index a8e946b08..2a407ca00 100644 --- a/llama_stack/apis/tools/tools.py +++ b/llama_stack/apis/tools/tools.py @@ -11,7 +11,7 @@ from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel, Field from typing_extensions import Protocol, runtime_checkable -from llama_stack.apis.common.content_types import InterleavedContent, URL +from llama_stack.apis.common.content_types import URL, InterleavedContent from llama_stack.apis.resource import Resource, ResourceType from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol diff --git a/llama_stack/cli/download.py b/llama_stack/cli/download.py index 379ac49ca..3ea534277 100644 --- a/llama_stack/cli/download.py +++ b/llama_stack/cli/download.py @@ -16,11 +16,9 @@ from pathlib import Path from typing import Dict, List, Optional import httpx - from llama_models.datatypes import Model from llama_models.sku_list import LlamaDownloadInfo from pydantic import BaseModel, ConfigDict - from rich.console import Console from rich.progress import ( BarColumn, diff --git a/llama_stack/cli/model/describe.py b/llama_stack/cli/model/describe.py index fc0190ca8..a25513633 100644 --- a/llama_stack/cli/model/describe.py +++ b/llama_stack/cli/model/describe.py @@ -8,7 +8,6 @@ import argparse import json from llama_models.sku_list import resolve_model - from termcolor import colored from llama_stack.cli.subcommand import Subcommand diff --git a/llama_stack/cli/model/model.py b/llama_stack/cli/model/model.py index 02e7f216f..3f8f55773 100644 --- a/llama_stack/cli/model/model.py +++ b/llama_stack/cli/model/model.py @@ -11,7 +11,6 @@ from llama_stack.cli.model.download import ModelDownload from llama_stack.cli.model.list import ModelList from llama_stack.cli.model.prompt_format import ModelPromptFormat from llama_stack.cli.model.verify_download import ModelVerifyDownload - from llama_stack.cli.subcommand import Subcommand diff --git a/llama_stack/cli/model/prompt_format.py b/llama_stack/cli/model/prompt_format.py index 388a63a42..2e1e1601e 100644 --- a/llama_stack/cli/model/prompt_format.py +++ 
b/llama_stack/cli/model/prompt_format.py @@ -8,7 +8,7 @@ import argparse import textwrap from io import StringIO -from llama_models.datatypes import CoreModelId, is_multimodal, model_family, ModelFamily +from llama_models.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family from llama_stack.cli.subcommand import Subcommand diff --git a/llama_stack/cli/model/safety_models.py b/llama_stack/cli/model/safety_models.py index 424ec367b..2321c4615 100644 --- a/llama_stack/cli/model/safety_models.py +++ b/llama_stack/cli/model/safety_models.py @@ -9,7 +9,6 @@ from typing import Any, Dict, Optional from llama_models.datatypes import CheckpointQuantizationFormat from llama_models.llama3.api.datatypes import SamplingParams from llama_models.sku_list import LlamaDownloadInfo - from pydantic import BaseModel, ConfigDict, Field diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index 65d37e9da..76f03aa5c 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -21,12 +21,11 @@ from prompt_toolkit.validation import Validator from termcolor import cprint from llama_stack.cli.table import print_table - from llama_stack.distribution.build import ( + SERVER_DEPENDENCIES, + ImageType, build_image, get_provider_dependencies, - ImageType, - SERVER_DEPENDENCIES, ) from llama_stack.distribution.datatypes import ( BuildConfig, diff --git a/llama_stack/cli/tests/test_stack_config.py b/llama_stack/cli/tests/test_stack_config.py index e1b9b23c5..2b7b2b210 100644 --- a/llama_stack/cli/tests/test_stack_config.py +++ b/llama_stack/cli/tests/test_stack_config.py @@ -8,6 +8,7 @@ from datetime import datetime import pytest import yaml + from llama_stack.distribution.configure import ( LLAMA_STACK_RUN_CONFIG_VERSION, parse_and_maybe_upgrade_config, diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py index b898312f4..9422c8457 100644 --- a/llama_stack/distribution/build.py +++ b/llama_stack/distribution/build.py @@ -8,7 +8,6 @@ import importlib.resources import logging import sys from enum import Enum - from pathlib import Path from typing import Dict, List @@ -16,11 +15,8 @@ from pydantic import BaseModel from termcolor import cprint from llama_stack.distribution.datatypes import BuildConfig, Provider - from llama_stack.distribution.distribution import get_provider_registry - from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR - from llama_stack.distribution.utils.exec import run_command, run_with_pty from llama_stack.providers.datatypes import Api diff --git a/llama_stack/distribution/client.py b/llama_stack/distribution/client.py index 8ed82f83e..b1d174ede 100644 --- a/llama_stack/distribution/client.py +++ b/llama_stack/distribution/client.py @@ -5,18 +5,16 @@ # the root directory of this source tree. 
import inspect - import json from collections.abc import AsyncIterator from enum import Enum -from typing import Any, get_args, get_origin, Type, Union +from typing import Any, Type, Union, get_args, get_origin import httpx from pydantic import BaseModel, parse_obj_as from termcolor import cprint from llama_stack.apis.version import LLAMA_STACK_API_VERSION - from llama_stack.providers.datatypes import RemoteProviderConfig _CLIENT_CLASSES = {} diff --git a/llama_stack/distribution/configure.py b/llama_stack/distribution/configure.py index 054f54864..825846a23 100644 --- a/llama_stack/distribution/configure.py +++ b/llama_stack/distribution/configure.py @@ -5,12 +5,11 @@ # the root directory of this source tree. import logging import textwrap - from typing import Any, Dict from llama_stack.distribution.datatypes import ( - DistributionSpec, LLAMA_STACK_RUN_CONFIG_VERSION, + DistributionSpec, Provider, StackRunConfig, ) @@ -20,7 +19,6 @@ from llama_stack.distribution.distribution import ( ) from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.prompt_for_config import prompt_for_config - from llama_stack.providers.datatypes import Api, ProviderSpec logger = logging.getLogger(__name__) diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 2c0f73974..55a15e5e9 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -13,10 +13,21 @@ import re from concurrent.futures import ThreadPoolExecutor from enum import Enum from pathlib import Path -from typing import Any, get_args, get_origin, Optional, TypeVar +from typing import Any, Optional, TypeVar, get_args, get_origin import httpx import yaml +from llama_stack_client import ( + NOT_GIVEN, + APIResponse, + AsyncAPIResponse, + AsyncLlamaStackClient, + AsyncStream, + LlamaStackClient, +) +from pydantic import BaseModel, TypeAdapter +from rich.console import Console +from termcolor import cprint from llama_stack.distribution.build import print_pip_install_help from llama_stack.distribution.configure import parse_and_maybe_upgrade_config @@ -35,17 +46,6 @@ from llama_stack.providers.utils.telemetry.tracing import ( setup_logger, start_trace, ) -from llama_stack_client import ( - APIResponse, - AsyncAPIResponse, - AsyncLlamaStackClient, - AsyncStream, - LlamaStackClient, - NOT_GIVEN, -) -from pydantic import BaseModel, TypeAdapter -from rich.console import Console -from termcolor import cprint T = TypeVar("T") diff --git a/llama_stack/distribution/routers/__init__.py b/llama_stack/distribution/routers/__init__.py index 156cda385..18197ca7f 100644 --- a/llama_stack/distribution/routers/__init__.py +++ b/llama_stack/distribution/routers/__init__.py @@ -7,7 +7,6 @@ from typing import Any, Dict from llama_stack.distribution.datatypes import RoutedProtocol - from llama_stack.distribution.store import DistributionRegistry from llama_stack.providers.datatypes import Api, RoutingTable diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 6cddcf73c..e716e44b0 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -6,7 +6,7 @@ from typing import Any, AsyncGenerator, Dict, List, Optional -from llama_stack.apis.common.content_types import InterleavedContent, URL +from llama_stack.apis.common.content_types import URL, InterleavedContent from llama_stack.apis.datasetio import DatasetIO, 
PaginatedRowsResult from llama_stack.apis.eval import ( AppEvalTaskConfig, diff --git a/llama_stack/distribution/server/endpoints.py b/llama_stack/distribution/server/endpoints.py index 45f1a2831..812f59ffd 100644 --- a/llama_stack/distribution/server/endpoints.py +++ b/llama_stack/distribution/server/endpoints.py @@ -10,11 +10,8 @@ from typing import Dict, List from pydantic import BaseModel from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup - from llama_stack.apis.version import LLAMA_STACK_API_VERSION - from llama_stack.distribution.resolver import api_protocol_map - from llama_stack.providers.datatypes import Api diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index bb735268b..0d234d506 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -7,9 +7,9 @@ import argparse import asyncio import functools -import logging import inspect import json +import logging import os import signal import sys @@ -21,7 +21,8 @@ from pathlib import Path from typing import Any, List, Union import yaml -from fastapi import Body, FastAPI, HTTPException, Path as FastapiPath, Request +from fastapi import Body, FastAPI, HTTPException, Request +from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse from pydantic import BaseModel, ValidationError diff --git a/llama_stack/distribution/store/tests/test_registry.py b/llama_stack/distribution/store/tests/test_registry.py index 1671cd30b..1ddba7472 100644 --- a/llama_stack/distribution/store/tests/test_registry.py +++ b/llama_stack/distribution/store/tests/test_registry.py @@ -8,9 +8,9 @@ import os import pytest import pytest_asyncio + from llama_stack.apis.inference import Model from llama_stack.apis.vector_dbs import VectorDB - from llama_stack.distribution.store.registry import ( CachedDiskDistributionRegistry, DiskDistributionRegistry, diff --git a/llama_stack/distribution/ui/modules/api.py b/llama_stack/distribution/ui/modules/api.py index 5f07a27c7..40caccda0 100644 --- a/llama_stack/distribution/ui/modules/api.py +++ b/llama_stack/distribution/ui/modules/api.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
import os - from typing import Optional from llama_stack_client import LlamaStackClient diff --git a/llama_stack/distribution/ui/page/distribution/resources.py b/llama_stack/distribution/ui/page/distribution/resources.py index 38d494570..94b840bcb 100644 --- a/llama_stack/distribution/ui/page/distribution/resources.py +++ b/llama_stack/distribution/ui/page/distribution/resources.py @@ -10,7 +10,6 @@ from page.distribution.models import models from page.distribution.scoring_functions import scoring_functions from page.distribution.shields import shields from page.distribution.vector_dbs import vector_dbs - from streamlit_option_menu import option_menu diff --git a/llama_stack/distribution/ui/page/evaluations/app_eval.py b/llama_stack/distribution/ui/page/evaluations/app_eval.py index 9b684ab80..26bc28451 100644 --- a/llama_stack/distribution/ui/page/evaluations/app_eval.py +++ b/llama_stack/distribution/ui/page/evaluations/app_eval.py @@ -8,7 +8,6 @@ import json import pandas as pd import streamlit as st - from modules.api import llama_stack_api from modules.utils import process_dataset diff --git a/llama_stack/distribution/ui/page/evaluations/native_eval.py b/llama_stack/distribution/ui/page/evaluations/native_eval.py index c4a44990f..112d9cff0 100644 --- a/llama_stack/distribution/ui/page/evaluations/native_eval.py +++ b/llama_stack/distribution/ui/page/evaluations/native_eval.py @@ -7,9 +7,7 @@ import json import pandas as pd - import streamlit as st - from modules.api import llama_stack_api diff --git a/llama_stack/distribution/ui/page/playground/rag.py b/llama_stack/distribution/ui/page/playground/rag.py index 8b30987cf..d84418241 100644 --- a/llama_stack/distribution/ui/page/playground/rag.py +++ b/llama_stack/distribution/ui/page/playground/rag.py @@ -9,7 +9,6 @@ from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.types.agent_create_params import AgentConfig from llama_stack_client.types.memory_insert_params import Document - from modules.api import llama_stack_api from modules.utils import data_url_from_file diff --git a/llama_stack/distribution/utils/config_dirs.py b/llama_stack/distribution/utils/config_dirs.py index eca59493f..e512c3576 100644 --- a/llama_stack/distribution/utils/config_dirs.py +++ b/llama_stack/distribution/utils/config_dirs.py @@ -7,7 +7,6 @@ import os from pathlib import Path - LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/"))) DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions" diff --git a/llama_stack/distribution/utils/prompt_for_config.py b/llama_stack/distribution/utils/prompt_for_config.py index 6a6223cc9..9b2b99022 100644 --- a/llama_stack/distribution/utils/prompt_for_config.py +++ b/llama_stack/distribution/utils/prompt_for_config.py @@ -8,13 +8,11 @@ import inspect import json import logging from enum import Enum - -from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union +from typing import Any, List, Literal, Optional, Type, Union, get_args, get_origin from pydantic import BaseModel from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefinedType - from typing_extensions import Annotated log = logging.getLogger(__name__) diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 8df91cce6..ccdaf76e7 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -11,7 +11,6 @@ from 
llama_models.schema_utils import json_schema_type from pydantic import BaseModel, Field from llama_stack.apis.datasets import Dataset - from llama_stack.apis.datatypes import Api from llama_stack.apis.eval_tasks import EvalTask from llama_stack.apis.models import Model diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 2f397f438..8ba7885cd 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -42,10 +42,10 @@ from llama_stack.apis.agents import ( Turn, ) from llama_stack.apis.common.content_types import ( + URL, TextContentItem, ToolCallDelta, ToolCallParseStatus, - URL, ) from llama_stack.apis.inference import ( ChatCompletionResponseEventType, diff --git a/llama_stack/providers/inline/agents/meta_reference/safety.py b/llama_stack/providers/inline/agents/meta_reference/safety.py index 69439522b..30ce52e3b 100644 --- a/llama_stack/providers/inline/agents/meta_reference/safety.py +++ b/llama_stack/providers/inline/agents/meta_reference/safety.py @@ -6,11 +6,9 @@ import asyncio import logging - from typing import List from llama_stack.apis.inference import Message - from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel log = logging.getLogger(__name__) diff --git a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py index b62bc5fee..4e3951ad3 100644 --- a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py +++ b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py @@ -41,7 +41,6 @@ from llama_stack.apis.tools import ( ToolInvocationResult, ) from llama_stack.apis.vector_io import QueryChunksResponse - from llama_stack.providers.inline.agents.meta_reference.agent_instance import ( MEMORY_QUERY_TOOL, ) diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index 54afae839..491f03f72 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -15,14 +15,12 @@ import pandas from llama_stack.apis.common.content_types import URL from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.datasets import Dataset - from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url from llama_stack.providers.utils.kvstore import kvstore_impl from .config import LocalFSDatasetIOConfig - DATASETS_PREFIX = "localfs_datasets:" diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index 1db627007..1c44caf7f 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -15,7 +15,6 @@ from llama_stack.apis.inference import Inference, UserMessage from llama_stack.apis.scoring import Scoring from llama_stack.distribution.datatypes import Api from llama_stack.providers.datatypes import EvalTasksProtocolPrivate - from llama_stack.providers.inline.agents.meta_reference.agent_instance import ( MEMORY_QUERY_TOOL, ) @@ -28,7 +27,6 @@ from llama_stack.providers.utils.kvstore import kvstore_impl from 
.....apis.common.job_types import Job from .....apis.eval.eval import Eval, EvalTaskConfig, EvaluateResponse, JobStatus - from .config import MetaReferenceEvalConfig EVAL_TASKS_PREFIX = "eval_tasks:" diff --git a/llama_stack/providers/inline/inference/meta_reference/config.py b/llama_stack/providers/inline/inference/meta_reference/config.py index 57939abaa..9e5f7747e 100644 --- a/llama_stack/providers/inline/inference/meta_reference/config.py +++ b/llama_stack/providers/inline/inference/meta_reference/config.py @@ -9,7 +9,6 @@ from typing import Any, Dict, Optional from pydantic import BaseModel, field_validator from llama_stack.apis.inference import QuantizationConfig - from llama_stack.providers.utils.inference import supported_inference_models diff --git a/llama_stack/providers/inline/inference/meta_reference/generation.py b/llama_stack/providers/inline/inference/meta_reference/generation.py index 51c10b0a8..e60c3b1be 100644 --- a/llama_stack/providers/inline/inference/meta_reference/generation.py +++ b/llama_stack/providers/inline/inference/meta_reference/generation.py @@ -37,7 +37,6 @@ from llama_models.llama3.reference_impl.multimodal.model import ( CrossAttentionTransformer, ) from llama_models.sku_list import resolve_model - from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData from pydantic import BaseModel @@ -47,7 +46,6 @@ from llama_stack.apis.inference import ( ResponseFormat, ResponseFormatType, ) - from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.providers.utils.inference.prompt_adapter import ( ChatCompletionRequestWithRawContent, diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 3caf4e2a5..61f0ee3f4 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -46,8 +46,8 @@ from llama_stack.providers.utils.inference.embedding_mixin import ( SentenceTransformerEmbeddingMixin, ) from llama_stack.providers.utils.inference.model_registry import ( - build_model_alias, ModelRegistryHelper, + build_model_alias, ) from llama_stack.providers.utils.inference.prompt_adapter import ( augment_content_with_response_format_prompt, diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index b8efddcbd..711a4632d 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -22,16 +22,13 @@ from typing import Callable, Generator, Literal, Optional, Union import torch import zmq - from fairscale.nn.model_parallel.initialize import ( get_model_parallel_group, get_model_parallel_rank, get_model_parallel_src_rank, ) - from pydantic import BaseModel, Field - -from torch.distributed.launcher.api import elastic_launch, LaunchConfig +from torch.distributed.launcher.api import LaunchConfig, elastic_launch from typing_extensions import Annotated from llama_stack.providers.utils.inference.prompt_adapter import ( diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls.py b/llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls.py index f5235d6c9..2b5e135b4 100644 --- a/llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls.py +++ 
b/llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls.py @@ -8,7 +8,6 @@ # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. import collections - import logging from typing import Optional, Type @@ -23,7 +22,7 @@ except ImportError: raise import torch -from torch import nn, Tensor +from torch import Tensor, nn class Fp8ScaledWeights: diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/fp8_txest_disabled.py b/llama_stack/providers/inline/inference/meta_reference/quantization/fp8_txest_disabled.py index 8f52d8c04..014a26f09 100644 --- a/llama_stack/providers/inline/inference/meta_reference/quantization/fp8_txest_disabled.py +++ b/llama_stack/providers/inline/inference/meta_reference/quantization/fp8_txest_disabled.py @@ -10,9 +10,9 @@ import unittest import torch - -from fp8_impls import ffn_swiglu_fp8_dynamic, FfnQuantizeMode, quantize_fp8 -from hypothesis import given, settings, strategies as st +from fp8_impls import FfnQuantizeMode, ffn_swiglu_fp8_dynamic, quantize_fp8 +from hypothesis import given, settings +from hypothesis import strategies as st from torch import Tensor diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py index 955527ff8..9be35ae70 100644 --- a/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py +++ b/llama_stack/providers/inline/inference/meta_reference/quantization/loader.py @@ -12,18 +12,13 @@ import os from typing import Any, Dict, List, Optional import torch - from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region - from llama_models.datatypes import CheckpointQuantizationFormat - from llama_models.llama3.api.args import ModelArgs from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock from llama_models.sku_list import resolve_model - -from torch import nn, Tensor - +from torch import Tensor, nn from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear from llama_stack.apis.inference import QuantizationType diff --git a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py index 4764d59b1..8bff70464 100644 --- a/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py +++ b/llama_stack/providers/inline/inference/meta_reference/quantization/scripts/quantize_checkpoint.py @@ -16,14 +16,12 @@ from pathlib import Path from typing import Optional import fire - import torch from fairscale.nn.model_parallel.initialize import ( get_model_parallel_rank, initialize_model_parallel, model_parallel_is_initialized, ) - from llama_models.llama3.api.args import ModelArgs from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock diff --git a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py index d34befbd9..6a83836e6 100644 --- a/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py +++ 
b/llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers.py @@ -15,9 +15,9 @@ from llama_stack.apis.inference import ( ResponseFormat, SamplingParams, ToolChoice, + ToolConfig, ToolDefinition, ToolPromptFormat, - ToolConfig, ) from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate from llama_stack.providers.utils.inference.embedding_mixin import ( diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py index 77c95cc7e..e75a9aac3 100644 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ b/llama_stack/providers/inline/inference/vllm/vllm.py @@ -37,9 +37,9 @@ from llama_stack.apis.inference import ( from llama_stack.apis.models import Model from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.openai_compat import ( - get_sampling_options, OpenAICompatCompletionChoice, OpenAICompatCompletionResponse, + get_sampling_options, process_chat_completion_response, process_chat_completion_stream_response, ) diff --git a/llama_stack/providers/inline/post_training/torchtune/common/utils.py b/llama_stack/providers/inline/post_training/torchtune/common/utils.py index 88011ead4..735af8c79 100644 --- a/llama_stack/providers/inline/post_training/torchtune/common/utils.py +++ b/llama_stack/providers/inline/post_training/torchtune/common/utils.py @@ -15,10 +15,8 @@ from typing import Any, Callable, Dict import torch from llama_models.datatypes import Model from llama_models.sku_list import resolve_model - from pydantic import BaseModel from torchtune.data._messages import InputOutputToMessages, ShareGPTToMessages - from torchtune.models.llama3 import llama3_tokenizer from torchtune.models.llama3._tokenizer import Llama3Tokenizer from torchtune.models.llama3_1 import lora_llama3_1_8b diff --git a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py index 82e6645d2..b556b59a6 100644 --- a/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py +++ b/llama_stack/providers/inline/post_training/torchtune/datasets/sft.py @@ -13,7 +13,6 @@ from typing import Any, Dict, List, Mapping import numpy as np - from torch.utils.data import Dataset from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX from torchtune.data._messages import validate_messages diff --git a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py index dbb3f714a..ef379aff2 100644 --- a/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py +++ b/llama_stack/providers/inline/post_training/torchtune/recipes/lora_finetuning_single_device.py @@ -18,9 +18,9 @@ from llama_models.sku_list import resolve_model from torch import nn from torch.optim import Optimizer from torch.utils.data import DataLoader, DistributedSampler -from torchtune import modules, training, utils as torchtune_utils +from torchtune import modules, training +from torchtune import utils as torchtune_utils from torchtune.data import padded_collate_sft - from torchtune.modules.loss import CEWithChunkedOutputLoss from torchtune.modules.peft import ( get_adapter_params, @@ -44,14 +44,11 @@ from llama_stack.apis.post_training import ( OptimizerConfig, TrainingConfig, ) - from llama_stack.distribution.utils.config_dirs import 
DEFAULT_CHECKPOINT_DIR - from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.providers.inline.post_training.common.validator import ( validate_input_dataset_schema, ) - from llama_stack.providers.inline.post_training.torchtune.common import utils from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import ( TorchtuneCheckpointer, diff --git a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py index 22af7ef23..606d11d2c 100644 --- a/llama_stack/providers/inline/safety/code_scanner/code_scanner.py +++ b/llama_stack/providers/inline/safety/code_scanner/code_scanner.py @@ -21,7 +21,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import CodeScannerConfig - log = logging.getLogger(__name__) ALLOWED_CODE_SCANNER_MODEL_IDS = [ diff --git a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py index e5168fb00..32d6d5100 100644 --- a/llama_stack/providers/inline/safety/llama_guard/llama_guard.py +++ b/llama_stack/providers/inline/safety/llama_guard/llama_guard.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import re - from string import Template from typing import Any, Dict, List, Optional @@ -25,10 +24,8 @@ from llama_stack.apis.safety import ( SafetyViolation, ViolationLevel, ) - from llama_stack.apis.shields import Shield from llama_stack.distribution.datatypes import Api - from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, @@ -36,7 +33,6 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import LlamaGuardConfig - CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?" 
SAFE_RESPONSE = "safe" diff --git a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py index 76d34e549..fce3e3d14 100644 --- a/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py +++ b/llama_stack/providers/inline/safety/prompt_guard/prompt_guard.py @@ -8,7 +8,6 @@ import logging from typing import Any, Dict, List import torch - from transformers import AutoModelForSequenceClassification, AutoTokenizer from llama_stack.apis.inference import Message @@ -19,7 +18,6 @@ from llama_stack.apis.safety import ( ViolationLevel, ) from llama_stack.apis.shields import Shield - from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.providers.datatypes import ShieldsProtocolPrivate from llama_stack.providers.utils.inference.prompt_adapter import ( diff --git a/llama_stack/providers/inline/scoring/basic/scoring.py b/llama_stack/providers/inline/scoring/basic/scoring.py index 24ce11872..13cd78243 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring.py +++ b/llama_stack/providers/inline/scoring/basic/scoring.py @@ -14,13 +14,13 @@ from llama_stack.apis.scoring import ( ScoringResult, ) from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams - from llama_stack.distribution.datatypes import Api from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate from llama_stack.providers.utils.common.data_schema_validator import ( get_valid_schemas, validate_dataset_schema, ) + from .config import BasicScoringConfig from .scoring_fn.equality_scoring_fn import EqualityScoringFn from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py index ad2037bdf..0bd6bdd48 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/equality_scoring_fn.py @@ -7,7 +7,6 @@ from typing import Any, Dict, Optional from llama_stack.apis.scoring import ScoringResultRow - from llama_stack.apis.scoring_functions import ScoringFnParams from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py index 7973eb939..9b24ff791 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/equality.py @@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import ( ScoringFn, ) - equality = ScoringFn( identifier="basic::equality", description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.", diff --git a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py index 0281e81b9..9cae66fa6 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/fn_defs/subset_of.py @@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import ( ScoringFn, ) - subset_of = ScoringFn( identifier="basic::subset_of", description="Returns 1.0 if the expected is included in generated, 0.0 otherwise.", diff --git 
a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py index 4fcfdba76..0606a9581 100644 --- a/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/basic/scoring_fn/regex_parser_scoring_fn.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import re - from typing import Any, Dict, Optional from llama_stack.apis.scoring import ScoringResultRow diff --git a/llama_stack/providers/inline/scoring/braintrust/braintrust.py b/llama_stack/providers/inline/scoring/braintrust/braintrust.py index ff3207e32..be0f023f3 100644 --- a/llama_stack/providers/inline/scoring/braintrust/braintrust.py +++ b/llama_stack/providers/inline/scoring/braintrust/braintrust.py @@ -29,9 +29,7 @@ from llama_stack.apis.scoring import ( ScoringResultRow, ) from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams - from llama_stack.distribution.datatypes import Api - from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate from llama_stack.providers.utils.common.data_schema_validator import ( @@ -39,8 +37,8 @@ from llama_stack.providers.utils.common.data_schema_validator import ( validate_dataset_schema, validate_row_schema, ) - from llama_stack.providers.utils.scoring.aggregation_utils import aggregate_metrics + from .config import BraintrustScoringConfig from .scoring_fn.fn_defs.answer_correctness import answer_correctness_fn_def from .scoring_fn.fn_defs.answer_relevancy import answer_relevancy_fn_def diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py index 1941417bb..4fe07f822 100644 --- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py +++ b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/answer_correctness.py @@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import ( ScoringFn, ) - answer_correctness_fn_def = ScoringFn( identifier="braintrust::answer-correctness", description=( diff --git a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py index 3c9fb88de..c621ecf7f 100644 --- a/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py +++ b/llama_stack/providers/inline/scoring/braintrust/scoring_fn/fn_defs/factuality.py @@ -11,7 +11,6 @@ from llama_stack.apis.scoring_functions import ( ScoringFn, ) - factuality_fn_def = ScoringFn( identifier="braintrust::factuality", description=( diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py index 333910c2c..dc562df1f 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring.py @@ -8,7 +8,6 @@ from typing import Any, Dict, List, Optional from llama_stack.apis.datasetio import DatasetIO from llama_stack.apis.datasets import Datasets from llama_stack.apis.inference.inference import Inference - from llama_stack.apis.scoring import ( ScoreBatchResponse, ScoreResponse, @@ -26,7 +25,6 @@ from 
llama_stack.providers.utils.common.data_schema_validator import ( from .config import LlmAsJudgeScoringConfig from .scoring_fn.llm_as_judge_scoring_fn import LlmAsJudgeScoringFn - LLM_JUDGE_FNS = [LlmAsJudgeScoringFn] diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py index 0b18bac01..205e0bbf3 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/fn_defs/llm_as_judge_base.py @@ -7,7 +7,6 @@ from llama_stack.apis.common.type_system import NumberType from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams, ScoringFn - llm_as_judge_base = ScoringFn( identifier="llm-as-judge::base", description="Llm As Judge Scoring Function", diff --git a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py index 0cf5a042a..457151c04 100644 --- a/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py +++ b/llama_stack/providers/inline/scoring/llm_as_judge/scoring_fn/llm_as_judge_scoring_fn.py @@ -4,18 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import re - from typing import Any, Dict, Optional from llama_stack.apis.inference.inference import Inference - from llama_stack.apis.scoring import ScoringResultRow from llama_stack.apis.scoring_functions import ScoringFnParams - from llama_stack.providers.utils.scoring.base_scoring_fn import RegisteredBaseScoringFn from .fn_defs.llm_as_judge_405b_simpleqa import llm_as_judge_405b_simpleqa - from .fn_defs.llm_as_judge_base import llm_as_judge_base diff --git a/llama_stack/providers/inline/telemetry/sample/sample.py b/llama_stack/providers/inline/telemetry/sample/sample.py index f07a185ef..a4147a1b2 100644 --- a/llama_stack/providers/inline/telemetry/sample/sample.py +++ b/llama_stack/providers/inline/telemetry/sample/sample.py @@ -5,6 +5,7 @@ # the root directory of this source tree. 
from llama_stack.apis.telemetry import Telemetry + from .config import SampleConfig diff --git a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py index 10f64ec94..f28ae248c 100644 --- a/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py +++ b/llama_stack/providers/inline/tool_runtime/code_interpreter/code_env_prefix.py @@ -82,7 +82,11 @@ import sys as _sys # them with linters - they're used in code_execution.py from contextlib import ( # noqa contextmanager as _contextmanager, +) +from contextlib import ( redirect_stderr as _redirect_stderr, +) +from contextlib import ( redirect_stdout as _redirect_stdout, ) from multiprocessing.connection import Connection as _Connection diff --git a/llama_stack/providers/inline/tool_runtime/rag/context_retriever.py b/llama_stack/providers/inline/tool_runtime/rag/context_retriever.py index e77ec76af..be18430e4 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/context_retriever.py +++ b/llama_stack/providers/inline/tool_runtime/rag/context_retriever.py @@ -9,7 +9,6 @@ from jinja2 import Template from llama_stack.apis.common.content_types import InterleavedContent from llama_stack.apis.inference import UserMessage - from llama_stack.apis.tools.rag_tool import ( DefaultRAGQueryGeneratorConfig, LLMRAGQueryGeneratorConfig, diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 7b0fff348..5695d4037 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -11,9 +11,9 @@ import string from typing import Any, Dict, List, Optional from llama_stack.apis.common.content_types import ( + URL, InterleavedContent, TextContentItem, - URL, ) from llama_stack.apis.inference import Inference from llama_stack.apis.tools import ( diff --git a/llama_stack/providers/inline/vector_io/faiss/__init__.py b/llama_stack/providers/inline/vector_io/faiss/__init__.py index 15b7259ad..8c075a0f8 100644 --- a/llama_stack/providers/inline/vector_io/faiss/__init__.py +++ b/llama_stack/providers/inline/vector_io/faiss/__init__.py @@ -7,6 +7,7 @@ from typing import Dict from llama_stack.providers.datatypes import Api, ProviderSpec + from .config import FaissImplConfig diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index 563d37bb1..565afdcf6 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -8,11 +8,9 @@ import base64 import io import json import logging - from typing import Any, Dict, List, Optional import faiss - import numpy as np from numpy.typing import NDArray diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py b/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py index 488a57660..5a2f07012 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/__init__.py @@ -5,7 +5,9 @@ # the root directory of this source tree. 
from typing import Dict + from llama_stack.providers.datatypes import Api, ProviderSpec + from .config import SQLiteVectorIOConfig diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/config.py b/llama_stack/providers/inline/vector_io/sqlite_vec/config.py index 60fe3ca2a..5a830ff27 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/config.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/config.py @@ -5,9 +5,10 @@ # the root directory of this source tree. # config.py -from pydantic import BaseModel from typing import Any, Dict +from pydantic import BaseModel + from llama_stack.providers.utils.kvstore.config import ( KVStoreConfig, SqliteKVStoreConfig, diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 019d260f8..fcd7cd8f9 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -4,13 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import sqlite3 -import sqlite_vec -import struct import logging +import sqlite3 +import struct +from typing import Any, Dict, List, Optional + import numpy as np +import sqlite_vec from numpy.typing import NDArray -from typing import List, Optional, Dict, Any from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO diff --git a/llama_stack/providers/remote/agents/sample/sample.py b/llama_stack/providers/remote/agents/sample/sample.py index f8b312f1e..02e889496 100644 --- a/llama_stack/providers/remote/agents/sample/sample.py +++ b/llama_stack/providers/remote/agents/sample/sample.py @@ -5,6 +5,7 @@ # the root directory of this source tree. 
from llama_stack.apis.agents import Agents + from .config import SampleConfig diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index cf17820dd..cd4e7f1f1 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -9,7 +9,6 @@ import datasets as hf_datasets from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult from llama_stack.apis.datasets import Dataset - from llama_stack.providers.datatypes import DatasetsProtocolPrivate from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url from llama_stack.providers.utils.kvstore import kvstore_impl diff --git a/llama_stack/providers/remote/inference/bedrock/bedrock.py b/llama_stack/providers/remote/inference/bedrock/bedrock.py index 54a674d7e..917ac7a25 100644 --- a/llama_stack/providers/remote/inference/bedrock/bedrock.py +++ b/llama_stack/providers/remote/inference/bedrock/bedrock.py @@ -31,13 +31,13 @@ from llama_stack.apis.inference import ( from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig from llama_stack.providers.utils.bedrock.client import create_bedrock_client from llama_stack.providers.utils.inference.model_registry import ( - build_model_alias, ModelRegistryHelper, + build_model_alias, ) from llama_stack.providers.utils.inference.openai_compat import ( - get_sampling_strategy_options, OpenAICompatCompletionChoice, OpenAICompatCompletionResponse, + get_sampling_strategy_options, process_chat_completion_response, process_chat_completion_stream_response, ) diff --git a/llama_stack/providers/remote/inference/cerebras/cerebras.py b/llama_stack/providers/remote/inference/cerebras/cerebras.py index 47f208129..2158fc5b4 100644 --- a/llama_stack/providers/remote/inference/cerebras/cerebras.py +++ b/llama_stack/providers/remote/inference/cerebras/cerebras.py @@ -29,8 +29,8 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.providers.utils.inference.model_registry import ( - build_model_alias, ModelRegistryHelper, + build_model_alias, ) from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, diff --git a/llama_stack/providers/remote/inference/databricks/databricks.py b/llama_stack/providers/remote/inference/databricks/databricks.py index ee3c6e99b..d56be1465 100644 --- a/llama_stack/providers/remote/inference/databricks/databricks.py +++ b/llama_stack/providers/remote/inference/databricks/databricks.py @@ -26,8 +26,8 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.providers.utils.inference.model_registry import ( - build_model_alias, ModelRegistryHelper, + build_model_alias, ) from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, diff --git a/llama_stack/providers/remote/inference/fireworks/fireworks.py b/llama_stack/providers/remote/inference/fireworks/fireworks.py index d978cb02e..7e8f85313 100644 --- a/llama_stack/providers/remote/inference/fireworks/fireworks.py +++ b/llama_stack/providers/remote/inference/fireworks/fireworks.py @@ -31,8 +31,8 @@ from llama_stack.apis.inference import ( ) from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.utils.inference.model_registry import ( - build_model_alias, ModelRegistryHelper, + build_model_alias, ) from llama_stack.providers.utils.inference.openai_compat import 
( convert_message_to_openai_dict, diff --git a/llama_stack/providers/remote/inference/groq/groq.py b/llama_stack/providers/remote/inference/groq/groq.py index 9b3c1abbf..59ec8b0d2 100644 --- a/llama_stack/providers/remote/inference/groq/groq.py +++ b/llama_stack/providers/remote/inference/groq/groq.py @@ -31,9 +31,9 @@ from llama_stack.apis.inference import ( from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.utils.inference.model_registry import ( + ModelRegistryHelper, build_model_alias, build_model_alias_with_just_provider_model_id, - ModelRegistryHelper, ) from .groq_utils import ( diff --git a/llama_stack/providers/remote/inference/groq/groq_utils.py b/llama_stack/providers/remote/inference/groq/groq_utils.py index d00e5c5a9..2445c1b39 100644 --- a/llama_stack/providers/remote/inference/groq/groq_utils.py +++ b/llama_stack/providers/remote/inference/groq/groq_utils.py @@ -24,10 +24,8 @@ from groq.types.chat.chat_completion_user_message_param import ( ) from groq.types.chat.completion_create_params import CompletionCreateParams from groq.types.shared.function_definition import FunctionDefinition - from llama_models.llama3.api.datatypes import ToolParamDefinition - from llama_stack.apis.common.content_types import ( TextDelta, ToolCallDelta, @@ -47,9 +45,9 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.providers.utils.inference.openai_compat import ( - get_sampling_strategy_options, - convert_tool_call, UnparseableToolCall, + convert_tool_call, + get_sampling_strategy_options, ) diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index b9b43006c..82343513f 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -29,8 +29,8 @@ from llama_stack.apis.inference import ( ToolConfig, ) from llama_stack.providers.utils.inference.model_registry import ( - build_model_alias, ModelRegistryHelper, + build_model_alias, ) from llama_stack.providers.utils.inference.prompt_adapter import content_has_media diff --git a/llama_stack/providers/remote/inference/nvidia/openai_utils.py b/llama_stack/providers/remote/inference/nvidia/openai_utils.py index 0a62a2ab4..c757c562c 100644 --- a/llama_stack/providers/remote/inference/nvidia/openai_utils.py +++ b/llama_stack/providers/remote/inference/nvidia/openai_utils.py @@ -22,17 +22,35 @@ from llama_models.llama3.api.datatypes import ( from openai import AsyncStream from openai.types.chat import ( ChatCompletionAssistantMessageParam as OpenAIChatCompletionAssistantMessage, +) +from openai.types.chat import ( ChatCompletionChunk as OpenAIChatCompletionChunk, +) +from openai.types.chat import ( ChatCompletionContentPartImageParam as OpenAIChatCompletionContentPartImageParam, +) +from openai.types.chat import ( ChatCompletionContentPartParam as OpenAIChatCompletionContentPartParam, +) +from openai.types.chat import ( ChatCompletionMessageParam as OpenAIChatCompletionMessage, +) +from openai.types.chat import ( ChatCompletionMessageToolCallParam as OpenAIChatCompletionMessageToolCall, +) +from openai.types.chat import ( ChatCompletionSystemMessageParam as OpenAIChatCompletionSystemMessage, +) +from openai.types.chat import ( ChatCompletionToolMessageParam as OpenAIChatCompletionToolMessage, +) +from openai.types.chat import ( ChatCompletionUserMessageParam 
as OpenAIChatCompletionUserMessage, ) from openai.types.chat.chat_completion import ( Choice as OpenAIChoice, +) +from openai.types.chat.chat_completion import ( ChoiceLogprobs as OpenAIChoiceLogprobs, # same as chat_completion_chunk ChoiceLogprobs ) from openai.types.chat.chat_completion_content_part_image_param import ( @@ -69,7 +87,6 @@ from llama_stack.apis.inference import ( ToolResponseMessage, UserMessage, ) - from llama_stack.providers.utils.inference.prompt_adapter import ( convert_image_content_to_url, ) diff --git a/llama_stack/providers/remote/inference/ollama/config.py b/llama_stack/providers/remote/inference/ollama/config.py index f056b9ab6..a5a4d48ab 100644 --- a/llama_stack/providers/remote/inference/ollama/config.py +++ b/llama_stack/providers/remote/inference/ollama/config.py @@ -8,7 +8,6 @@ from typing import Any, Dict from pydantic import BaseModel - DEFAULT_OLLAMA_URL = "http://localhost:11434" diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 05a5d2d7a..1c12d0d91 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -36,14 +36,14 @@ from llama_stack.apis.inference import ( from llama_stack.apis.models import Model, ModelType from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( + ModelRegistryHelper, build_model_alias, build_model_alias_with_just_provider_model_id, - ModelRegistryHelper, ) from llama_stack.providers.utils.inference.openai_compat import ( - get_sampling_options, OpenAICompatCompletionChoice, OpenAICompatCompletionResponse, + get_sampling_options, process_chat_completion_response, process_chat_completion_stream_response, process_completion_response, diff --git a/llama_stack/providers/remote/inference/runpod/runpod.py b/llama_stack/providers/remote/inference/runpod/runpod.py index c7b20b9a1..a3c615418 100644 --- a/llama_stack/providers/remote/inference/runpod/runpod.py +++ b/llama_stack/providers/remote/inference/runpod/runpod.py @@ -8,14 +8,12 @@ from typing import AsyncGenerator from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import Message from llama_models.llama3.api.tokenizer import Tokenizer - from openai import OpenAI from llama_stack.apis.inference import * # noqa: F403 # from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper - from llama_stack.providers.utils.inference.openai_compat import ( get_sampling_options, process_chat_completion_response, diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py index 18a78e69c..3546ee977 100644 --- a/llama_stack/providers/remote/inference/sambanova/sambanova.py +++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py @@ -24,8 +24,8 @@ from llama_stack.apis.common.content_types import ( ) from llama_stack.apis.inference import * # noqa: F403 from llama_stack.providers.utils.inference.model_registry import ( - build_model_alias, ModelRegistryHelper, + build_model_alias, ) from llama_stack.providers.utils.inference.openai_compat import ( process_chat_completion_stream_response, diff --git a/llama_stack/providers/remote/inference/sample/sample.py b/llama_stack/providers/remote/inference/sample/sample.py index 51ce879eb..106381618 100644 
--- a/llama_stack/providers/remote/inference/sample/sample.py +++ b/llama_stack/providers/remote/inference/sample/sample.py @@ -6,6 +6,7 @@ from llama_stack.apis.inference import Inference from llama_stack.apis.models import Model + from .config import SampleConfig diff --git a/llama_stack/providers/remote/inference/tgi/tgi.py b/llama_stack/providers/remote/inference/tgi/tgi.py index 97a6621fb..72eaa6c31 100644 --- a/llama_stack/providers/remote/inference/tgi/tgi.py +++ b/llama_stack/providers/remote/inference/tgi/tgi.py @@ -33,13 +33,13 @@ from llama_stack.apis.inference import ( from llama_stack.apis.models import Model from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( - build_model_alias, ModelRegistryHelper, + build_model_alias, ) from llama_stack.providers.utils.inference.openai_compat import ( - get_sampling_options, OpenAICompatCompletionChoice, OpenAICompatCompletionResponse, + get_sampling_options, process_chat_completion_response, process_chat_completion_stream_response, process_completion_response, diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index a165b01d9..916e64ad4 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -30,8 +30,8 @@ from llama_stack.apis.inference import ( ) from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.providers.utils.inference.model_registry import ( - build_model_alias, ModelRegistryHelper, + build_model_alias, ) from llama_stack.providers.utils.inference.openai_compat import ( convert_message_to_openai_dict, diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 3574768b5..8f9cf68a8 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -13,10 +13,14 @@ from llama_models.llama3.api.tokenizer import Tokenizer from llama_models.sku_list import all_registered_models from openai import OpenAI -from llama_stack.apis.common.content_types import InterleavedContent, ToolCallDelta, ToolCallParseStatus, TextDelta +from llama_stack.apis.common.content_types import InterleavedContent, TextDelta, ToolCallDelta, ToolCallParseStatus from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, + ChatCompletionResponseEvent, + ChatCompletionResponseEventType, + ChatCompletionResponseStreamChunk, + CompletionMessage, CompletionRequest, CompletionResponse, CompletionResponseStreamChunk, @@ -31,26 +35,22 @@ from llama_stack.apis.inference import ( ToolConfig, ToolDefinition, ToolPromptFormat, - CompletionMessage, - ChatCompletionResponseEventType, - ChatCompletionResponseStreamChunk, - ChatCompletionResponseEvent, ) from llama_stack.apis.models import Model, ModelType from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.utils.inference.model_registry import ( - build_model_alias, ModelRegistryHelper, + build_model_alias, ) from llama_stack.providers.utils.inference.openai_compat import ( - convert_message_to_openai_dict, - get_sampling_options, - process_completion_response, - process_completion_stream_response, OpenAICompatCompletionResponse, UnparseableToolCall, + convert_message_to_openai_dict, convert_tool_call, + get_sampling_options, 
process_chat_completion_stream_response, + process_completion_response, + process_completion_stream_response, ) from llama_stack.providers.utils.inference.prompt_adapter import ( completion_request_to_prompt, diff --git a/llama_stack/providers/remote/safety/bedrock/bedrock.py b/llama_stack/providers/remote/safety/bedrock/bedrock.py index b9d9b9825..c8cd129f2 100644 --- a/llama_stack/providers/remote/safety/bedrock/bedrock.py +++ b/llama_stack/providers/remote/safety/bedrock/bedrock.py @@ -6,11 +6,9 @@ import json import logging - from typing import Any, Dict, List from llama_stack.apis.inference import Message - from llama_stack.apis.safety import ( RunShieldResponse, Safety, @@ -23,7 +21,6 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client from .config import BedrockSafetyConfig - logger = logging.getLogger(__name__) diff --git a/llama_stack/providers/remote/safety/sample/sample.py b/llama_stack/providers/remote/safety/sample/sample.py index 180e6c3b5..7645c69e9 100644 --- a/llama_stack/providers/remote/safety/sample/sample.py +++ b/llama_stack/providers/remote/safety/sample/sample.py @@ -6,6 +6,7 @@ from llama_stack.apis.safety import Safety from llama_stack.apis.shields import Shield + from .config import SampleConfig diff --git a/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py b/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py index 3b05f5632..2ddf7b4fe 100644 --- a/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py +++ b/llama_stack/providers/remote/tool_runtime/model_context_protocol/__init__.py @@ -7,7 +7,6 @@ from pydantic import BaseModel from .config import ModelContextProtocolConfig - from .model_context_protocol import ModelContextProtocolToolRuntimeImpl diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 3ebdd089b..47ef30b5a 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -21,6 +21,7 @@ from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, ) + from .config import ChromaRemoteImplConfig log = logging.getLogger(__name__) diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index eb1c9aab1..693aacd76 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -10,15 +10,13 @@ from typing import Any, Dict, List, Optional, Tuple import psycopg2 from numpy.typing import NDArray from psycopg2 import sql -from psycopg2.extras import execute_values, Json - +from psycopg2.extras import Json, execute_values from pydantic import BaseModel, TypeAdapter from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate - from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index e7ad136eb..b2eae3dad 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -20,6 +20,7 @@ from 
llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, ) + from .config import QdrantConfig log = logging.getLogger(__name__) diff --git a/llama_stack/providers/remote/vector_io/sample/sample.py b/llama_stack/providers/remote/vector_io/sample/sample.py index e311be39d..b0ba50315 100644 --- a/llama_stack/providers/remote/vector_io/sample/sample.py +++ b/llama_stack/providers/remote/vector_io/sample/sample.py @@ -6,6 +6,7 @@ from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import VectorIO + from .config import SampleConfig diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index c57b57609..c4d3c39ac 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import json import logging - from typing import Any, Dict, List, Optional import weaviate diff --git a/llama_stack/providers/tests/agents/conftest.py b/llama_stack/providers/tests/agents/conftest.py index 5759b77c5..3a6ce278a 100644 --- a/llama_stack/providers/tests/agents/conftest.py +++ b/llama_stack/providers/tests/agents/conftest.py @@ -13,7 +13,6 @@ from ..conftest import ( ) from ..inference.fixtures import INFERENCE_FIXTURES from ..safety.fixtures import SAFETY_FIXTURES, safety_model_from_shield - from ..tools.fixtures import TOOL_RUNTIME_FIXTURES from ..vector_io.fixtures import VECTOR_IO_FIXTURES from .fixtures import AGENTS_FIXTURES diff --git a/llama_stack/providers/tests/agents/test_agents.py b/llama_stack/providers/tests/agents/test_agents.py index 37d0c04b5..45b276cc3 100644 --- a/llama_stack/providers/tests/agents/test_agents.py +++ b/llama_stack/providers/tests/agents/test_agents.py @@ -23,7 +23,6 @@ from llama_stack.apis.agents import ( ToolExecutionStep, Turn, ) - from llama_stack.apis.inference import CompletionMessage, UserMessage from llama_stack.apis.safety import ViolationLevel from llama_stack.providers.datatypes import Api diff --git a/llama_stack/providers/tests/agents/test_persistence.py b/llama_stack/providers/tests/agents/test_persistence.py index a1d69c9ca..f02279e8d 100644 --- a/llama_stack/providers/tests/agents/test_persistence.py +++ b/llama_stack/providers/tests/agents/test_persistence.py @@ -13,7 +13,6 @@ from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from .fixtures import pick_inference_model - from .utils import create_agent_session diff --git a/llama_stack/providers/tests/conftest.py b/llama_stack/providers/tests/conftest.py index cf88e8fe8..d3e715b7e 100644 --- a/llama_stack/providers/tests/conftest.py +++ b/llama_stack/providers/tests/conftest.py @@ -6,13 +6,11 @@ import os from collections import defaultdict - from pathlib import Path from typing import Any, Dict, List, Optional import pytest import yaml - from dotenv import load_dotenv from pydantic import BaseModel, Field from termcolor import colored diff --git a/llama_stack/providers/tests/datasetio/fixtures.py b/llama_stack/providers/tests/datasetio/fixtures.py index d288198ca..27aedb645 100644 --- a/llama_stack/providers/tests/datasetio/fixtures.py +++ b/llama_stack/providers/tests/datasetio/fixtures.py @@ -8,7 +8,6 @@ import pytest import pytest_asyncio from llama_stack.distribution.datatypes import Api, Provider - from llama_stack.providers.tests.resolver import 
construct_stack_for_test from ..conftest import ProviderFixture, remote_stack_fixture diff --git a/llama_stack/providers/tests/eval/conftest.py b/llama_stack/providers/tests/eval/conftest.py index 84eae2efa..c1da6ba42 100644 --- a/llama_stack/providers/tests/eval/conftest.py +++ b/llama_stack/providers/tests/eval/conftest.py @@ -7,16 +7,14 @@ import pytest from ..agents.fixtures import AGENTS_FIXTURES - from ..conftest import get_provider_fixture_overrides - from ..datasetio.fixtures import DATASETIO_FIXTURES from ..inference.fixtures import INFERENCE_FIXTURES from ..safety.fixtures import SAFETY_FIXTURES from ..scoring.fixtures import SCORING_FIXTURES from ..tools.fixtures import TOOL_RUNTIME_FIXTURES -from .fixtures import EVAL_FIXTURES from ..vector_io.fixtures import VECTOR_IO_FIXTURES +from .fixtures import EVAL_FIXTURES DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( diff --git a/llama_stack/providers/tests/eval/fixtures.py b/llama_stack/providers/tests/eval/fixtures.py index 009e65fb3..c6d15bbf5 100644 --- a/llama_stack/providers/tests/eval/fixtures.py +++ b/llama_stack/providers/tests/eval/fixtures.py @@ -8,8 +8,8 @@ import pytest import pytest_asyncio from llama_stack.distribution.datatypes import Api, ModelInput, Provider - from llama_stack.providers.tests.resolver import construct_stack_for_test + from ..conftest import ProviderFixture, remote_stack_fixture diff --git a/llama_stack/providers/tests/eval/test_eval.py b/llama_stack/providers/tests/eval/test_eval.py index 40835bf53..ec3d08728 100644 --- a/llama_stack/providers/tests/eval/test_eval.py +++ b/llama_stack/providers/tests/eval/test_eval.py @@ -9,7 +9,6 @@ import pytest from llama_stack.apis.common.content_types import URL from llama_stack.apis.common.type_system import ChatCompletionInputType, StringType - from llama_stack.apis.eval.eval import ( AppEvalTaskConfig, BenchmarkEvalTaskConfig, @@ -19,6 +18,7 @@ from llama_stack.apis.inference import SamplingParams from llama_stack.apis.scoring_functions import LLMAsJudgeScoringFnParams from llama_stack.distribution.datatypes import Api from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset + from .constants import JUDGE_PROMPT # How to run this test: diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index b33a217bb..2a782befc 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -11,13 +11,11 @@ import pytest_asyncio from llama_stack.apis.models import ModelInput, ModelType from llama_stack.distribution.datatypes import Api, Provider - from llama_stack.providers.inline.inference.meta_reference import ( MetaReferenceInferenceConfig, ) from llama_stack.providers.inline.inference.vllm import VLLMConfig from llama_stack.providers.remote.inference.bedrock import BedrockConfig - from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig from llama_stack.providers.remote.inference.groq import GroqConfig diff --git a/llama_stack/providers/tests/inference/groq/test_groq_utils.py b/llama_stack/providers/tests/inference/groq/test_groq_utils.py index a28dd308e..3eba991c1 100644 --- a/llama_stack/providers/tests/inference/groq/test_groq_utils.py +++ b/llama_stack/providers/tests/inference/groq/test_groq_utils.py @@ -10,11 +10,13 @@ import pytest from groq.types.chat.chat_completion import ChatCompletion, Choice from 
groq.types.chat.chat_completion_chunk import ( ChatCompletionChunk, - Choice as StreamChoice, ChoiceDelta, ChoiceDeltaToolCall, ChoiceDeltaToolCallFunction, ) +from groq.types.chat.chat_completion_chunk import ( + Choice as StreamChoice, +) from groq.types.chat.chat_completion_message import ChatCompletionMessage from groq.types.chat.chat_completion_message_tool_call import ( ChatCompletionMessageToolCall, @@ -23,6 +25,7 @@ from groq.types.chat.chat_completion_message_tool_call import ( from groq.types.shared.function_definition import FunctionDefinition from llama_models.datatypes import GreedySamplingStrategy, TopPSamplingStrategy from llama_models.llama3.api.datatypes import ToolParamDefinition + from llama_stack.apis.common.content_types import ToolCallParseStatus from llama_stack.apis.inference import ( ChatCompletionRequest, diff --git a/llama_stack/providers/tests/inference/groq/test_init.py b/llama_stack/providers/tests/inference/groq/test_init.py index d23af5934..4cdd3bfd5 100644 --- a/llama_stack/providers/tests/inference/groq/test_init.py +++ b/llama_stack/providers/tests/inference/groq/test_init.py @@ -5,11 +5,11 @@ # the root directory of this source tree. import pytest + from llama_stack.apis.inference import Inference from llama_stack.providers.remote.inference.groq import get_adapter_impl from llama_stack.providers.remote.inference.groq.config import GroqConfig from llama_stack.providers.remote.inference.groq.groq import GroqInferenceAdapter - from llama_stack.providers.remote.inference.ollama import OllamaImplConfig diff --git a/llama_stack/providers/tests/inference/test_model_registration.py b/llama_stack/providers/tests/inference/test_model_registration.py index 664564d22..7c41b07ef 100644 --- a/llama_stack/providers/tests/inference/test_model_registration.py +++ b/llama_stack/providers/tests/inference/test_model_registration.py @@ -8,7 +8,6 @@ from unittest.mock import AsyncMock, patch import pytest - # How to run this test: # # torchrun $CONDA_PREFIX/bin/pytest -v -s -k "meta_reference" --inference-model="Llama3.1-8B-Instruct" diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 6a7259123..14ed2fc4b 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -6,7 +6,6 @@ import pytest - from llama_models.llama3.api.datatypes import ( SamplingParams, StopReason, @@ -15,7 +14,6 @@ from llama_models.llama3.api.datatypes import ( ToolParamDefinition, ToolPromptFormat, ) - from pydantic import BaseModel, ValidationError from llama_stack.apis.common.content_types import ToolCallParseStatus @@ -35,7 +33,6 @@ from llama_stack.apis.models import ListModelsResponse, Model from .utils import group_chunks - # How to run this test: # # pytest -v -s llama_stack/providers/tests/inference/test_text_inference.py diff --git a/llama_stack/providers/tests/post_training/conftest.py b/llama_stack/providers/tests/post_training/conftest.py index 3cd60e53a..b6d95444b 100644 --- a/llama_stack/providers/tests/post_training/conftest.py +++ b/llama_stack/providers/tests/post_training/conftest.py @@ -7,9 +7,7 @@ import pytest from ..conftest import get_provider_fixture_overrides - from ..datasetio.fixtures import DATASETIO_FIXTURES - from .fixtures import POST_TRAINING_FIXTURES DEFAULT_PROVIDER_COMBINATIONS = [ diff --git a/llama_stack/providers/tests/post_training/fixtures.py 
b/llama_stack/providers/tests/post_training/fixtures.py index fd8a9e4f6..7c3ff3ddb 100644 --- a/llama_stack/providers/tests/post_training/fixtures.py +++ b/llama_stack/providers/tests/post_training/fixtures.py @@ -8,13 +8,10 @@ import pytest import pytest_asyncio from llama_stack.apis.common.content_types import URL - from llama_stack.apis.common.type_system import StringType from llama_stack.apis.datasets import DatasetInput from llama_stack.apis.models import ModelInput - from llama_stack.distribution.datatypes import Api, Provider - from llama_stack.providers.tests.resolver import construct_stack_for_test from ..conftest import ProviderFixture diff --git a/llama_stack/providers/tests/report.py b/llama_stack/providers/tests/report.py index b7a238908..3901dc2e3 100644 --- a/llama_stack/providers/tests/report.py +++ b/llama_stack/providers/tests/report.py @@ -12,10 +12,8 @@ import pytest from llama_models.datatypes import CoreModelId from llama_models.sku_list import all_registered_models from pytest import ExitCode - from pytest_html.basereport import _process_outcome - INFERENCE_APIS = ["chat_completion"] FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"] SUPPORTED_MODELS = { diff --git a/llama_stack/providers/tests/safety/conftest.py b/llama_stack/providers/tests/safety/conftest.py index 10a8517fc..3e46f0d50 100644 --- a/llama_stack/providers/tests/safety/conftest.py +++ b/llama_stack/providers/tests/safety/conftest.py @@ -7,11 +7,9 @@ import pytest from ..conftest import get_provider_fixture_overrides - from ..inference.fixtures import INFERENCE_FIXTURES from .fixtures import SAFETY_FIXTURES - DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( { diff --git a/llama_stack/providers/tests/safety/fixtures.py b/llama_stack/providers/tests/safety/fixtures.py index 32883bfab..a0c00ee7c 100644 --- a/llama_stack/providers/tests/safety/fixtures.py +++ b/llama_stack/providers/tests/safety/fixtures.py @@ -8,14 +8,11 @@ import pytest import pytest_asyncio from llama_stack.apis.models import ModelInput - from llama_stack.apis.shields import ShieldInput - from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.safety.llama_guard import LlamaGuardConfig from llama_stack.providers.inline.safety.prompt_guard import PromptGuardConfig from llama_stack.providers.remote.safety.bedrock import BedrockSafetyConfig - from llama_stack.providers.tests.resolver import construct_stack_for_test from ..conftest import ProviderFixture, remote_stack_fixture diff --git a/llama_stack/providers/tests/scoring/conftest.py b/llama_stack/providers/tests/scoring/conftest.py index 450f65695..9278d3c2d 100644 --- a/llama_stack/providers/tests/scoring/conftest.py +++ b/llama_stack/providers/tests/scoring/conftest.py @@ -7,7 +7,6 @@ import pytest from ..conftest import get_provider_fixture_overrides - from ..datasetio.fixtures import DATASETIO_FIXTURES from ..inference.fixtures import INFERENCE_FIXTURES from .fixtures import SCORING_FIXTURES diff --git a/llama_stack/providers/tests/scoring/fixtures.py b/llama_stack/providers/tests/scoring/fixtures.py index 2cf32b1e2..09f31cbc2 100644 --- a/llama_stack/providers/tests/scoring/fixtures.py +++ b/llama_stack/providers/tests/scoring/fixtures.py @@ -8,10 +8,10 @@ import pytest import pytest_asyncio from llama_stack.apis.models import ModelInput - from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.scoring.braintrust import BraintrustScoringConfig from llama_stack.providers.tests.resolver 
import construct_stack_for_test + from ..conftest import ProviderFixture, remote_stack_fixture from ..env import get_env_or_fail diff --git a/llama_stack/providers/tests/vector_io/conftest.py b/llama_stack/providers/tests/vector_io/conftest.py index 3a02ac712..3da64ff2e 100644 --- a/llama_stack/providers/tests/vector_io/conftest.py +++ b/llama_stack/providers/tests/vector_io/conftest.py @@ -11,11 +11,9 @@ from ..conftest import ( get_provider_fixture_overrides_from_test_config, get_test_config_for_api, ) - from ..inference.fixtures import INFERENCE_FIXTURES from .fixtures import VECTOR_IO_FIXTURES - DEFAULT_PROVIDER_COMBINATIONS = [ pytest.param( { diff --git a/llama_stack/providers/tests/vector_io/fixtures.py b/llama_stack/providers/tests/vector_io/fixtures.py index 54a76141f..60d174d9e 100644 --- a/llama_stack/providers/tests/vector_io/fixtures.py +++ b/llama_stack/providers/tests/vector_io/fixtures.py @@ -12,7 +12,6 @@ import pytest_asyncio from llama_stack.apis.models import ModelInput, ModelType from llama_stack.distribution.datatypes import Api, Provider - from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig from llama_stack.providers.inline.vector_io.faiss import FaissImplConfig from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig diff --git a/llama_stack/providers/tests/vector_io/test_vector_io.py b/llama_stack/providers/tests/vector_io/test_vector_io.py index 81b080f63..77bc24a21 100644 --- a/llama_stack/providers/tests/vector_io/test_vector_io.py +++ b/llama_stack/providers/tests/vector_io/test_vector_io.py @@ -9,10 +9,8 @@ import uuid import pytest from llama_stack.apis.tools import RAGDocument - from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB from llama_stack.apis.vector_io import QueryChunksResponse - from llama_stack.providers.utils.memory.vector_store import make_overlapped_chunks # How to run this test: diff --git a/llama_stack/providers/tests/vector_io/test_vector_store.py b/llama_stack/providers/tests/vector_io/test_vector_store.py index 2a41a8982..e0d340657 100644 --- a/llama_stack/providers/tests/vector_io/test_vector_store.py +++ b/llama_stack/providers/tests/vector_io/test_vector_store.py @@ -12,8 +12,7 @@ from pathlib import Path import pytest from llama_stack.apis.tools import RAGDocument - -from llama_stack.providers.utils.memory.vector_store import content_from_doc, URL +from llama_stack.providers.utils.memory.vector_store import URL, content_from_doc DUMMY_PDF_PATH = Path(os.path.abspath(__file__)).parent / "fixtures" / "dummy.pdf" diff --git a/llama_stack/providers/utils/common/data_schema_validator.py b/llama_stack/providers/utils/common/data_schema_validator.py index 8b5618950..3d14c4148 100644 --- a/llama_stack/providers/utils/common/data_schema_validator.py +++ b/llama_stack/providers/utils/common/data_schema_validator.py @@ -12,7 +12,6 @@ from llama_stack.apis.common.type_system import ( CompletionInputType, StringType, ) - from llama_stack.distribution.datatypes import Api diff --git a/llama_stack/providers/utils/datasetio/url_utils.py b/llama_stack/providers/utils/datasetio/url_utils.py index da1e84d4d..f54cb55eb 100644 --- a/llama_stack/providers/utils/datasetio/url_utils.py +++ b/llama_stack/providers/utils/datasetio/url_utils.py @@ -11,7 +11,6 @@ from urllib.parse import unquote import pandas from llama_stack.apis.common.content_types import URL - from llama_stack.providers.utils.memory.vector_store import parse_data_url diff --git 
a/llama_stack/providers/utils/inference/model_registry.py b/llama_stack/providers/utils/inference/model_registry.py index dea951395..9345da949 100644 --- a/llama_stack/providers/utils/inference/model_registry.py +++ b/llama_stack/providers/utils/inference/model_registry.py @@ -11,7 +11,6 @@ from llama_models.sku_list import all_registered_models from llama_stack.apis.models.models import ModelType from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate - from llama_stack.providers.utils.inference import ( ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR, ) diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index 7480ff2c7..00e291e8f 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -13,7 +13,6 @@ from llama_models.datatypes import ( TopKSamplingStrategy, TopPSamplingStrategy, ) - from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import StopReason, ToolCall from openai.types.chat import ChatCompletionMessageToolCall @@ -26,7 +25,6 @@ from llama_stack.apis.common.content_types import ( ToolCallDelta, ToolCallParseStatus, ) - from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -39,7 +37,6 @@ from llama_stack.apis.inference import ( Message, TokenLogProbs, ) - from llama_stack.providers.utils.inference.prompt_adapter import ( convert_image_content_to_url, ) diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 57875e64b..15149e059 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -13,7 +13,7 @@ import re from typing import List, Optional, Tuple, Union import httpx -from llama_models.datatypes import is_multimodal, ModelFamily +from llama_models.datatypes import ModelFamily, is_multimodal from llama_models.llama3.api.chat_format import ChatFormat from llama_models.llama3.api.datatypes import ( RawContent, @@ -47,9 +47,9 @@ from llama_stack.apis.inference import ( ResponseFormat, ResponseFormatType, SystemMessage, + SystemMessageBehavior, ToolChoice, UserMessage, - SystemMessageBehavior, ) from llama_stack.providers.utils.inference import supported_inference_models diff --git a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py index e7a33503b..bc0488aac 100644 --- a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py +++ b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py @@ -5,7 +5,6 @@ # the root directory of this source tree. 
import os - from datetime import datetime from typing import List, Optional diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py index 310db18b0..1ac1cf8d5 100644 --- a/llama_stack/providers/utils/memory/vector_store.py +++ b/llama_stack/providers/utils/memory/vector_store.py @@ -15,13 +15,14 @@ from urllib.parse import unquote import chardet import httpx import numpy as np - from llama_models.llama3.api.tokenizer import Tokenizer +from numpy.typing import NDArray +from pypdf import PdfReader from llama_stack.apis.common.content_types import ( + URL, InterleavedContent, TextContentItem, - URL, ) from llama_stack.apis.tools import RAGDocument from llama_stack.apis.vector_dbs import VectorDB @@ -30,9 +31,6 @@ from llama_stack.providers.datatypes import Api from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, ) -from numpy.typing import NDArray - -from pypdf import PdfReader log = logging.getLogger(__name__) diff --git a/llama_stack/scripts/distro_codegen.py b/llama_stack/scripts/distro_codegen.py index c73c15d41..825a039ef 100644 --- a/llama_stack/scripts/distro_codegen.py +++ b/llama_stack/scripts/distro_codegen.py @@ -16,11 +16,10 @@ from typing import Iterator from rich.progress import Progress, SpinnerColumn, TextColumn from llama_stack.distribution.build import ( - get_provider_dependencies, SERVER_DEPENDENCIES, + get_provider_dependencies, ) - REPO_ROOT = Path(__file__).parent.parent.parent diff --git a/llama_stack/templates/dell/dell.py b/llama_stack/templates/dell/dell.py index 5781da7f4..116fbd285 100644 --- a/llama_stack/templates/dell/dell.py +++ b/llama_stack/templates/dell/dell.py @@ -15,7 +15,6 @@ from llama_stack.distribution.datatypes import ( from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) - from llama_stack.templates.template import DistributionTemplate, RunConfigSettings diff --git a/llama_stack/templates/sambanova/sambanova.py b/llama_stack/templates/sambanova/sambanova.py index 70b54b010..6d7477c8e 100644 --- a/llama_stack/templates/sambanova/sambanova.py +++ b/llama_stack/templates/sambanova/sambanova.py @@ -16,7 +16,6 @@ from llama_stack.distribution.datatypes import ( ) from llama_stack.providers.remote.inference.sambanova import SambaNovaImplConfig from llama_stack.providers.remote.inference.sambanova.sambanova import MODEL_ALIASES - from llama_stack.templates.template import DistributionTemplate, RunConfigSettings diff --git a/llama_stack/templates/template.py b/llama_stack/templates/template.py index 04a09741c..cb5b07be3 100644 --- a/llama_stack/templates/template.py +++ b/llama_stack/templates/template.py @@ -9,6 +9,7 @@ from typing import Dict, List, Literal, Optional, Tuple import jinja2 import yaml +from pydantic import BaseModel, Field from llama_stack.apis.models.models import ModelType from llama_stack.distribution.datatypes import ( @@ -24,7 +25,6 @@ from llama_stack.distribution.datatypes import ( from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig -from pydantic import BaseModel, Field class RunConfigSettings(BaseModel): diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index d14a7003f..f42341f72 100644 --- a/tests/client-sdk/agents/test_agents.py +++ 
b/tests/client-sdk/agents/test_agents.py @@ -13,12 +13,14 @@ from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.client_tool import ClientTool from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.types import ToolResponseMessage -from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument from llama_stack_client.types.memory_insert_params import Document from llama_stack_client.types.shared.completion_message import CompletionMessage +from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig from llama_stack_client.types.tool_def_param import Parameter -from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig, ToolChoice + +from llama_stack.apis.agents.agents import AgentConfig as Server__AgentConfig +from llama_stack.apis.agents.agents import ToolChoice class TestClientTool(ClientTool): diff --git a/tests/client-sdk/conftest.py b/tests/client-sdk/conftest.py index 8c44242fe..b397f7ab3 100644 --- a/tests/client-sdk/conftest.py +++ b/tests/client-sdk/conftest.py @@ -6,11 +6,11 @@ import os import pytest +from llama_stack_client import LlamaStackClient +from report import Report from llama_stack import LlamaStackAsLibraryClient from llama_stack.providers.tests.env import get_env_or_fail -from llama_stack_client import LlamaStackClient -from report import Report def pytest_configure(config): diff --git a/tests/client-sdk/report.py b/tests/client-sdk/report.py index 5e8203ecb..543562541 100644 --- a/tests/client-sdk/report.py +++ b/tests/client-sdk/report.py @@ -22,15 +22,13 @@ from llama_models.sku_list import ( llama3_instruct_models, safety_models, ) +from metadata import API_MAPS +from pytest import CollectReport +from termcolor import cprint from llama_stack.providers.datatypes import Api from llama_stack.providers.tests.env import get_env_or_fail -from metadata import API_MAPS - -from pytest import CollectReport -from termcolor import cprint - def featured_models(): models = [ diff --git a/tests/client-sdk/tool_runtime/test_rag_tool.py b/tests/client-sdk/tool_runtime/test_rag_tool.py index f776bd0a9..40940f1ef 100644 --- a/tests/client-sdk/tool_runtime/test_rag_tool.py +++ b/tests/client-sdk/tool_runtime/test_rag_tool.py @@ -7,7 +7,6 @@ import random import pytest - from llama_stack_client.types import Document From 8ff27b58fa85f240cac7f6c70e9953096f098865 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Thu, 13 Feb 2025 13:15:49 -0500 Subject: [PATCH 26/27] chore: Consistent naming for VectorIO providers (#1023) # What does this PR do? This changes all VectorIO provider classes to follow the pattern `VectorIOConfig` and `VectorIOAdapter`. All VectorIO API endpoints are already consistent under `/vector-io`. Note that the VectorDB API endpoint stays unchanged as `/vector-dbs`. ## Test Plan I don't have a way to test all providers. This is a simple renaming, so things should work as expected.
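As a rough sketch of the convention (the class names below come from the Faiss diff in this patch; the `db_path` field and the no-op `initialize()` body are invented placeholders, not the provider's real schema or logic):

```python
# Sketch of the target naming: <Provider>VectorIOConfig paired with
# <Provider>VectorIOAdapter, e.g. FaissImplConfig -> FaissVectorIOConfig and
# FaissVectorIOImpl -> FaissVectorIOAdapter.
from pydantic import BaseModel


class FaissVectorIOConfig(BaseModel):
    db_path: str = "faiss_store.db"  # placeholder field; the real config holds a KVStoreConfig


class FaissVectorIOAdapter:
    def __init__(self, config: FaissVectorIOConfig) -> None:
        self.config = config

    async def initialize(self) -> None:
        # The real adapter restores persisted indexes here; no-op in this sketch.
        pass


async def get_provider_impl(config: FaissVectorIOConfig, _deps) -> FaissVectorIOAdapter:
    # Mirrors the provider entry points touched by the patch: validate the config
    # type, construct the adapter, initialize it, and return it to the stack.
    assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}"
    impl = FaissVectorIOAdapter(config)
    await impl.initialize()
    return impl
```

Only the class names change; the adapters' behavior and the `/vector-io` endpoints are left as they were.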
--------- Signed-off-by: Yuan Tang --- llama_stack/apis/telemetry/telemetry.py | 2 +- .../inline/vector_io/chroma/__init__.py | 4 ++-- .../inline/vector_io/chroma/config.py | 2 +- .../inline/vector_io/faiss/__init__.py | 10 +++++----- .../inline/vector_io/faiss/config.py | 2 +- .../providers/inline/vector_io/faiss/faiss.py | 6 +++--- .../remote/vector_io/chroma/__init__.py | 4 ++-- .../remote/vector_io/chroma/chroma.py | 7 +++---- .../remote/vector_io/chroma/config.py | 2 +- .../remote/vector_io/pgvector/__init__.py | 8 ++++---- .../remote/vector_io/pgvector/config.py | 2 +- .../remote/vector_io/pgvector/pgvector.py | 6 +++--- .../remote/vector_io/qdrant/__init__.py | 8 ++++---- .../remote/vector_io/qdrant/config.py | 2 +- .../remote/vector_io/qdrant/qdrant.py | 6 +++--- .../remote/vector_io/sample/__init__.py | 8 ++++---- .../remote/vector_io/sample/config.py | 2 +- .../remote/vector_io/sample/sample.py | 6 +++--- .../remote/vector_io/weaviate/__init__.py | 8 ++++---- .../remote/vector_io/weaviate/config.py | 2 +- .../remote/vector_io/weaviate/weaviate.py | 6 +++--- .../providers/tests/vector_io/fixtures.py | 20 +++++++++---------- llama_stack/templates/bedrock/bedrock.py | 4 ++-- llama_stack/templates/cerebras/cerebras.py | 4 ++-- llama_stack/templates/fireworks/fireworks.py | 4 ++-- .../templates/hf-endpoint/hf_endpoint.py | 4 ++-- .../templates/hf-serverless/hf_serverless.py | 4 ++-- .../meta-reference-gpu/meta_reference.py | 4 ++-- .../meta_reference.py | 4 ++-- llama_stack/templates/ollama/ollama.py | 4 ++-- llama_stack/templates/remote-vllm/vllm.py | 4 ++-- llama_stack/templates/tgi/tgi.py | 4 ++-- llama_stack/templates/together/together.py | 4 ++-- llama_stack/templates/vllm-gpu/vllm.py | 4 ++-- 34 files changed, 85 insertions(+), 86 deletions(-) diff --git a/llama_stack/apis/telemetry/telemetry.py b/llama_stack/apis/telemetry/telemetry.py index 5622aaeac..63ae1dc73 100644 --- a/llama_stack/apis/telemetry/telemetry.py +++ b/llama_stack/apis/telemetry/telemetry.py @@ -13,8 +13,8 @@ from typing import ( Literal, Optional, Protocol, - runtime_checkable, Union, + runtime_checkable, ) from llama_models.llama3.api.datatypes import Primitive diff --git a/llama_stack/providers/inline/vector_io/chroma/__init__.py b/llama_stack/providers/inline/vector_io/chroma/__init__.py index 56a4ac21c..abaf01097 100644 --- a/llama_stack/providers/inline/vector_io/chroma/__init__.py +++ b/llama_stack/providers/inline/vector_io/chroma/__init__.py @@ -8,10 +8,10 @@ from typing import Dict from llama_stack.providers.datatypes import Api, ProviderSpec -from .config import ChromaInlineImplConfig +from .config import ChromaVectorIOConfig -async def get_provider_impl(config: ChromaInlineImplConfig, deps: Dict[Api, ProviderSpec]): +async def get_provider_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]): from llama_stack.providers.remote.vector_io.chroma.chroma import ( ChromaVectorIOAdapter, ) diff --git a/llama_stack/providers/inline/vector_io/chroma/config.py b/llama_stack/providers/inline/vector_io/chroma/config.py index efbd77faf..a1fb60fa6 100644 --- a/llama_stack/providers/inline/vector_io/chroma/config.py +++ b/llama_stack/providers/inline/vector_io/chroma/config.py @@ -9,7 +9,7 @@ from typing import Any, Dict from pydantic import BaseModel -class ChromaInlineImplConfig(BaseModel): +class ChromaVectorIOConfig(BaseModel): db_path: str @classmethod diff --git a/llama_stack/providers/inline/vector_io/faiss/__init__.py b/llama_stack/providers/inline/vector_io/faiss/__init__.py index 
8c075a0f8..f23e1fa4f 100644 --- a/llama_stack/providers/inline/vector_io/faiss/__init__.py +++ b/llama_stack/providers/inline/vector_io/faiss/__init__.py @@ -8,14 +8,14 @@ from typing import Dict from llama_stack.providers.datatypes import Api, ProviderSpec -from .config import FaissImplConfig +from .config import FaissVectorIOConfig -async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]): - from .faiss import FaissVectorIOImpl +async def get_provider_impl(config: FaissVectorIOConfig, deps: Dict[Api, ProviderSpec]): + from .faiss import FaissVectorIOAdapter - assert isinstance(config, FaissImplConfig), f"Unexpected config type: {type(config)}" + assert isinstance(config, FaissVectorIOConfig), f"Unexpected config type: {type(config)}" - impl = FaissVectorIOImpl(config, deps[Api.inference]) + impl = FaissVectorIOAdapter(config, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/vector_io/faiss/config.py b/llama_stack/providers/inline/vector_io/faiss/config.py index d82104477..ae859842d 100644 --- a/llama_stack/providers/inline/vector_io/faiss/config.py +++ b/llama_stack/providers/inline/vector_io/faiss/config.py @@ -16,7 +16,7 @@ from llama_stack.providers.utils.kvstore.config import ( @json_schema_type -class FaissImplConfig(BaseModel): +class FaissVectorIOConfig(BaseModel): kvstore: KVStoreConfig @classmethod diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index 565afdcf6..b52fb074c 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -24,7 +24,7 @@ from llama_stack.providers.utils.memory.vector_store import ( VectorDBWithIndex, ) -from .config import FaissImplConfig +from .config import FaissVectorIOConfig logger = logging.getLogger(__name__) @@ -112,8 +112,8 @@ class FaissIndex(EmbeddingIndex): return QueryChunksResponse(chunks=chunks, scores=scores) -class FaissVectorIOImpl(VectorIO, VectorDBsProtocolPrivate): - def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None: +class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): + def __init__(self, config: FaissVectorIOConfig, inference_api: Api.inference) -> None: self.config = config self.inference_api = inference_api self.cache = {} diff --git a/llama_stack/providers/remote/vector_io/chroma/__init__.py b/llama_stack/providers/remote/vector_io/chroma/__init__.py index 9990120f5..8646b04d6 100644 --- a/llama_stack/providers/remote/vector_io/chroma/__init__.py +++ b/llama_stack/providers/remote/vector_io/chroma/__init__.py @@ -8,10 +8,10 @@ from typing import Dict from llama_stack.providers.datatypes import Api, ProviderSpec -from .config import ChromaRemoteImplConfig +from .config import ChromaVectorIOConfig -async def get_adapter_impl(config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec]): +async def get_adapter_impl(config: ChromaVectorIOConfig, deps: Dict[Api, ProviderSpec]): from .chroma import ChromaVectorIOAdapter impl = ChromaVectorIOAdapter(config, deps[Api.inference]) diff --git a/llama_stack/providers/remote/vector_io/chroma/chroma.py b/llama_stack/providers/remote/vector_io/chroma/chroma.py index 47ef30b5a..f894a8e65 100644 --- a/llama_stack/providers/remote/vector_io/chroma/chroma.py +++ b/llama_stack/providers/remote/vector_io/chroma/chroma.py @@ -16,13 +16,12 @@ from llama_stack.apis.inference import InterleavedContent from llama_stack.apis.vector_dbs import 
VectorDB from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate -from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig from llama_stack.providers.utils.memory.vector_store import ( EmbeddingIndex, VectorDBWithIndex, ) -from .config import ChromaRemoteImplConfig +from .config import ChromaVectorIOConfig log = logging.getLogger(__name__) @@ -89,7 +88,7 @@ class ChromaIndex(EmbeddingIndex): class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): def __init__( self, - config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig], + config: Union[ChromaVectorIOConfig, ChromaVectorIOConfig], inference_api: Api.inference, ) -> None: log.info(f"Initializing ChromaVectorIOAdapter with url: {config}") @@ -100,7 +99,7 @@ class ChromaVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): self.cache = {} async def initialize(self) -> None: - if isinstance(self.config, ChromaRemoteImplConfig): + if isinstance(self.config, ChromaVectorIOConfig): log.info(f"Connecting to Chroma server at: {self.config.url}") url = self.config.url.rstrip("/") parsed = urlparse(url) diff --git a/llama_stack/providers/remote/vector_io/chroma/config.py b/llama_stack/providers/remote/vector_io/chroma/config.py index 68ca2c967..cbbfa9de3 100644 --- a/llama_stack/providers/remote/vector_io/chroma/config.py +++ b/llama_stack/providers/remote/vector_io/chroma/config.py @@ -9,7 +9,7 @@ from typing import Any, Dict from pydantic import BaseModel -class ChromaRemoteImplConfig(BaseModel): +class ChromaVectorIOConfig(BaseModel): url: str @classmethod diff --git a/llama_stack/providers/remote/vector_io/pgvector/__init__.py b/llama_stack/providers/remote/vector_io/pgvector/__init__.py index bdca7acb1..089d890b7 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/__init__.py +++ b/llama_stack/providers/remote/vector_io/pgvector/__init__.py @@ -8,12 +8,12 @@ from typing import Dict from llama_stack.providers.datatypes import Api, ProviderSpec -from .config import PGVectorConfig +from .config import PGVectorVectorIOConfig -async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]): - from .pgvector import PGVectorVectorDBAdapter +async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: Dict[Api, ProviderSpec]): + from .pgvector import PGVectorVectorIOAdapter - impl = PGVectorVectorDBAdapter(config, deps[Api.inference]) + impl = PGVectorVectorIOAdapter(config, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/vector_io/pgvector/config.py b/llama_stack/providers/remote/vector_io/pgvector/config.py index 41983e7b2..2a64d7c67 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/config.py +++ b/llama_stack/providers/remote/vector_io/pgvector/config.py @@ -9,7 +9,7 @@ from pydantic import BaseModel, Field @json_schema_type -class PGVectorConfig(BaseModel): +class PGVectorVectorIOConfig(BaseModel): host: str = Field(default="localhost") port: int = Field(default=5432) db: str = Field(default="postgres") diff --git a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py index 693aacd76..269cf554b 100644 --- a/llama_stack/providers/remote/vector_io/pgvector/pgvector.py +++ b/llama_stack/providers/remote/vector_io/pgvector/pgvector.py @@ -22,7 +22,7 @@ from llama_stack.providers.utils.memory.vector_store import ( VectorDBWithIndex, ) -from .config import 
PGVectorConfig +from .config import PGVectorVectorIOConfig log = logging.getLogger(__name__) @@ -121,8 +121,8 @@ class PGVectorIndex(EmbeddingIndex): cur.execute(f"DROP TABLE IF EXISTS {self.table_name}") -class PGVectorVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate): - def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None: +class PGVectorVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): + def __init__(self, config: PGVectorVectorIOConfig, inference_api: Api.inference) -> None: self.config = config self.inference_api = inference_api self.conn = None diff --git a/llama_stack/providers/remote/vector_io/qdrant/__init__.py b/llama_stack/providers/remote/vector_io/qdrant/__init__.py index c584e29ef..f5bb7f84c 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/__init__.py +++ b/llama_stack/providers/remote/vector_io/qdrant/__init__.py @@ -8,12 +8,12 @@ from typing import Dict from llama_stack.providers.datatypes import Api, ProviderSpec -from .config import QdrantConfig +from .config import QdrantVectorIOConfig -async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]): - from .qdrant import QdrantVectorDBAdapter +async def get_adapter_impl(config: QdrantVectorIOConfig, deps: Dict[Api, ProviderSpec]): + from .qdrant import QdrantVectorIOAdapter - impl = QdrantVectorDBAdapter(config, deps[Api.inference]) + impl = QdrantVectorIOAdapter(config, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/vector_io/qdrant/config.py b/llama_stack/providers/remote/vector_io/qdrant/config.py index a6a5a6ff6..613cfa6e4 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/config.py +++ b/llama_stack/providers/remote/vector_io/qdrant/config.py @@ -11,7 +11,7 @@ from pydantic import BaseModel @json_schema_type -class QdrantConfig(BaseModel): +class QdrantVectorIOConfig(BaseModel): location: Optional[str] = None url: Optional[str] = None port: Optional[int] = 6333 diff --git a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py index b2eae3dad..e1091e2cf 100644 --- a/llama_stack/providers/remote/vector_io/qdrant/qdrant.py +++ b/llama_stack/providers/remote/vector_io/qdrant/qdrant.py @@ -21,7 +21,7 @@ from llama_stack.providers.utils.memory.vector_store import ( VectorDBWithIndex, ) -from .config import QdrantConfig +from .config import QdrantVectorIOConfig log = logging.getLogger(__name__) CHUNK_ID_KEY = "_chunk_id" @@ -98,8 +98,8 @@ class QdrantIndex(EmbeddingIndex): await self.client.delete_collection(collection_name=self.collection_name) -class QdrantVectorDBAdapter(VectorIO, VectorDBsProtocolPrivate): - def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None: +class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): + def __init__(self, config: QdrantVectorIOConfig, inference_api: Api.inference) -> None: self.config = config self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True)) self.cache = {} diff --git a/llama_stack/providers/remote/vector_io/sample/__init__.py b/llama_stack/providers/remote/vector_io/sample/__init__.py index c9accdf62..221f47b1c 100644 --- a/llama_stack/providers/remote/vector_io/sample/__init__.py +++ b/llama_stack/providers/remote/vector_io/sample/__init__.py @@ -6,12 +6,12 @@ from typing import Any -from .config import SampleConfig +from .config import SampleVectorIOConfig -async def get_adapter_impl(config: SampleConfig, _deps) -> Any: - from .sample import 
SampleMemoryImpl +async def get_adapter_impl(config: SampleVectorIOConfig, _deps) -> Any: + from .sample import SampleVectorIOImpl - impl = SampleMemoryImpl(config) + impl = SampleVectorIOImpl(config) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/vector_io/sample/config.py b/llama_stack/providers/remote/vector_io/sample/config.py index 4b7404a26..5126e5eff 100644 --- a/llama_stack/providers/remote/vector_io/sample/config.py +++ b/llama_stack/providers/remote/vector_io/sample/config.py @@ -7,6 +7,6 @@ from pydantic import BaseModel -class SampleConfig(BaseModel): +class SampleVectorIOConfig(BaseModel): host: str = "localhost" port: int = 9999 diff --git a/llama_stack/providers/remote/vector_io/sample/sample.py b/llama_stack/providers/remote/vector_io/sample/sample.py index b0ba50315..cb7193cf4 100644 --- a/llama_stack/providers/remote/vector_io/sample/sample.py +++ b/llama_stack/providers/remote/vector_io/sample/sample.py @@ -7,11 +7,11 @@ from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_io import VectorIO -from .config import SampleConfig +from .config import SampleVectorIOConfig -class SampleMemoryImpl(VectorIO): - def __init__(self, config: SampleConfig): +class SampleVectorIOImpl(VectorIO): + def __init__(self, config: SampleVectorIOConfig): self.config = config async def register_vector_db(self, vector_db: VectorDB) -> None: diff --git a/llama_stack/providers/remote/vector_io/weaviate/__init__.py b/llama_stack/providers/remote/vector_io/weaviate/__init__.py index f7120bec0..c93c628d8 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/__init__.py +++ b/llama_stack/providers/remote/vector_io/weaviate/__init__.py @@ -8,12 +8,12 @@ from typing import Dict from llama_stack.providers.datatypes import Api, ProviderSpec -from .config import WeaviateConfig, WeaviateRequestProviderData # noqa: F401 +from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig # noqa: F401 -async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]): - from .weaviate import WeaviateMemoryAdapter +async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: Dict[Api, ProviderSpec]): + from .weaviate import WeaviateVectorIOAdapter - impl = WeaviateMemoryAdapter(config, deps[Api.inference]) + impl = WeaviateVectorIOAdapter(config, deps[Api.inference]) await impl.initialize() return impl diff --git a/llama_stack/providers/remote/vector_io/weaviate/config.py b/llama_stack/providers/remote/vector_io/weaviate/config.py index d0811acb4..6aad9a5a6 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/config.py +++ b/llama_stack/providers/remote/vector_io/weaviate/config.py @@ -12,5 +12,5 @@ class WeaviateRequestProviderData(BaseModel): weaviate_cluster_url: str -class WeaviateConfig(BaseModel): +class WeaviateVectorIOConfig(BaseModel): pass diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py index c4d3c39ac..52aa2f3a3 100644 --- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py +++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py @@ -23,7 +23,7 @@ from llama_stack.providers.utils.memory.vector_store import ( VectorDBWithIndex, ) -from .config import WeaviateConfig, WeaviateRequestProviderData +from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig log = logging.getLogger(__name__) @@ -85,12 +85,12 @@ class WeaviateIndex(EmbeddingIndex): 
collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids)) -class WeaviateMemoryAdapter( +class WeaviateVectorIOAdapter( VectorIO, NeedsRequestProviderData, VectorDBsProtocolPrivate, ): - def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None: + def __init__(self, config: WeaviateVectorIOConfig, inference_api: Api.inference) -> None: self.config = config self.inference_api = inference_api self.client_cache = {} diff --git a/llama_stack/providers/tests/vector_io/fixtures.py b/llama_stack/providers/tests/vector_io/fixtures.py index 60d174d9e..30a2679d7 100644 --- a/llama_stack/providers/tests/vector_io/fixtures.py +++ b/llama_stack/providers/tests/vector_io/fixtures.py @@ -12,12 +12,12 @@ import pytest_asyncio from llama_stack.apis.models import ModelInput, ModelType from llama_stack.distribution.datatypes import Api, Provider -from llama_stack.providers.inline.vector_io.chroma import ChromaInlineImplConfig -from llama_stack.providers.inline.vector_io.faiss import FaissImplConfig +from llama_stack.providers.inline.vector_io.chroma import ChromaVectorIOConfig as InlineChromaVectorIOConfig +from llama_stack.providers.inline.vector_io.faiss import FaissVectorIOConfig from llama_stack.providers.inline.vector_io.sqlite_vec import SQLiteVectorIOConfig -from llama_stack.providers.remote.vector_io.chroma import ChromaRemoteImplConfig -from llama_stack.providers.remote.vector_io.pgvector import PGVectorConfig -from llama_stack.providers.remote.vector_io.weaviate import WeaviateConfig +from llama_stack.providers.remote.vector_io.chroma import ChromaVectorIOConfig +from llama_stack.providers.remote.vector_io.pgvector import PGVectorVectorIOConfig +from llama_stack.providers.remote.vector_io.weaviate import WeaviateVectorIOConfig from llama_stack.providers.tests.resolver import construct_stack_for_test from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig @@ -45,7 +45,7 @@ def vector_io_faiss() -> ProviderFixture: Provider( provider_id="faiss", provider_type="inline::faiss", - config=FaissImplConfig( + config=FaissVectorIOConfig( kvstore=SqliteKVStoreConfig(db_path=temp_file.name).model_dump(), ).model_dump(), ) @@ -76,7 +76,7 @@ def vector_io_pgvector() -> ProviderFixture: Provider( provider_id="pgvector", provider_type="remote::pgvector", - config=PGVectorConfig( + config=PGVectorVectorIOConfig( host=os.getenv("PGVECTOR_HOST", "localhost"), port=os.getenv("PGVECTOR_PORT", 5432), db=get_env_or_fail("PGVECTOR_DB"), @@ -95,7 +95,7 @@ def vector_io_weaviate() -> ProviderFixture: Provider( provider_id="weaviate", provider_type="remote::weaviate", - config=WeaviateConfig().model_dump(), + config=WeaviateVectorIOConfig().model_dump(), ) ], provider_data=dict( @@ -109,12 +109,12 @@ def vector_io_weaviate() -> ProviderFixture: def vector_io_chroma() -> ProviderFixture: url = os.getenv("CHROMA_URL") if url: - config = ChromaRemoteImplConfig(url=url) + config = ChromaVectorIOConfig(url=url) provider_type = "remote::chromadb" else: if not os.getenv("CHROMA_DB_PATH"): raise ValueError("CHROMA_DB_PATH or CHROMA_URL must be set") - config = ChromaInlineImplConfig(db_path=os.getenv("CHROMA_DB_PATH")) + config = InlineChromaVectorIOConfig(db_path=os.getenv("CHROMA_DB_PATH")) provider_type = "inline::chromadb" return ProviderFixture( providers=[ diff --git a/llama_stack/templates/bedrock/bedrock.py b/llama_stack/templates/bedrock/bedrock.py index 0c8259285..af1d48b7f 100644 --- a/llama_stack/templates/bedrock/bedrock.py +++ 
b/llama_stack/templates/bedrock/bedrock.py @@ -10,7 +10,7 @@ from llama_models.sku_list import all_registered_models from llama_stack.apis.models import ModelInput from llama_stack.distribution.datatypes import Provider, ToolGroupInput -from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.remote.inference.bedrock.bedrock import MODEL_ALIASES from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -37,7 +37,7 @@ def get_distribution_template() -> DistributionTemplate: vector_io_provider = Provider( provider_id="faiss", provider_type="inline::faiss", - config=FaissImplConfig.sample_run_config(f"distributions/{name}"), + config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"), ) core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()} diff --git a/llama_stack/templates/cerebras/cerebras.py b/llama_stack/templates/cerebras/cerebras.py index 2dfae04f8..870240feb 100644 --- a/llama_stack/templates/cerebras/cerebras.py +++ b/llama_stack/templates/cerebras/cerebras.py @@ -13,7 +13,7 @@ from llama_stack.distribution.datatypes import ModelInput, Provider, ToolGroupIn from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) -from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig from llama_stack.providers.remote.inference.cerebras.cerebras import model_aliases from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -69,7 +69,7 @@ def get_distribution_template() -> DistributionTemplate: vector_io_provider = Provider( provider_id="faiss", provider_type="inline::faiss", - config=FaissImplConfig.sample_run_config(f"distributions/{name}"), + config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"), ) default_tool_groups = [ ToolGroupInput( diff --git a/llama_stack/templates/fireworks/fireworks.py b/llama_stack/templates/fireworks/fireworks.py index ec350010b..e2e2ca99c 100644 --- a/llama_stack/templates/fireworks/fireworks.py +++ b/llama_stack/templates/fireworks/fireworks.py @@ -18,7 +18,7 @@ from llama_stack.distribution.datatypes import ( from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) -from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig from llama_stack.providers.remote.inference.fireworks.fireworks import MODEL_ALIASES from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -58,7 +58,7 @@ def get_distribution_template() -> DistributionTemplate: vector_io_provider = Provider( provider_id="faiss", provider_type="inline::faiss", - config=FaissImplConfig.sample_run_config(f"distributions/{name}"), + config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"), ) core_model_to_hf_repo = {m.descriptor(): m.huggingface_repo for m in all_registered_models()} diff --git a/llama_stack/templates/hf-endpoint/hf_endpoint.py b/llama_stack/templates/hf-endpoint/hf_endpoint.py index 4533fd95b..62584929c 100644 --- 
a/llama_stack/templates/hf-endpoint/hf_endpoint.py +++ b/llama_stack/templates/hf-endpoint/hf_endpoint.py @@ -14,7 +14,7 @@ from llama_stack.distribution.datatypes import ( from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) -from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.remote.inference.tgi import InferenceEndpointImplConfig from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -51,7 +51,7 @@ def get_distribution_template() -> DistributionTemplate: vector_io_provider = Provider( provider_id="faiss", provider_type="inline::faiss", - config=FaissImplConfig.sample_run_config(f"distributions/{name}"), + config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"), ) inference_model = ModelInput( diff --git a/llama_stack/templates/hf-serverless/hf_serverless.py b/llama_stack/templates/hf-serverless/hf_serverless.py index 8438de7a5..46efb6f0b 100644 --- a/llama_stack/templates/hf-serverless/hf_serverless.py +++ b/llama_stack/templates/hf-serverless/hf_serverless.py @@ -14,7 +14,7 @@ from llama_stack.distribution.datatypes import ( from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) -from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.remote.inference.tgi import InferenceAPIImplConfig from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -52,7 +52,7 @@ def get_distribution_template() -> DistributionTemplate: vector_io_provider = Provider( provider_id="faiss", provider_type="inline::faiss", - config=FaissImplConfig.sample_run_config(f"distributions/{name}"), + config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"), ) inference_model = ModelInput( diff --git a/llama_stack/templates/meta-reference-gpu/meta_reference.py b/llama_stack/templates/meta-reference-gpu/meta_reference.py index a3f82b0c8..9bff981d1 100644 --- a/llama_stack/templates/meta-reference-gpu/meta_reference.py +++ b/llama_stack/templates/meta-reference-gpu/meta_reference.py @@ -19,7 +19,7 @@ from llama_stack.providers.inline.inference.meta_reference import ( from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) -from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -58,7 +58,7 @@ def get_distribution_template() -> DistributionTemplate: vector_io_provider = Provider( provider_id="faiss", provider_type="inline::faiss", - config=FaissImplConfig.sample_run_config(f"distributions/{name}"), + config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"), ) inference_model = ModelInput( diff --git a/llama_stack/templates/meta-reference-quantized-gpu/meta_reference.py b/llama_stack/templates/meta-reference-quantized-gpu/meta_reference.py index 8c2a6ec9f..fca15fcc5 100644 --- a/llama_stack/templates/meta-reference-quantized-gpu/meta_reference.py +++ b/llama_stack/templates/meta-reference-quantized-gpu/meta_reference.py @@ -14,7 +14,7 @@ from 
llama_stack.providers.inline.inference.meta_reference import ( from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) -from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -67,7 +67,7 @@ def get_distribution_template() -> DistributionTemplate: vector_io_provider = Provider( provider_id="faiss", provider_type="inline::faiss", - config=FaissImplConfig.sample_run_config(f"distributions/{name}"), + config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"), ) inference_model = ModelInput( diff --git a/llama_stack/templates/ollama/ollama.py b/llama_stack/templates/ollama/ollama.py index a762e757a..f3383cd5a 100644 --- a/llama_stack/templates/ollama/ollama.py +++ b/llama_stack/templates/ollama/ollama.py @@ -16,7 +16,7 @@ from llama_stack.distribution.datatypes import ( from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) -from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.inline.vector_io.sqlite_vec.config import SQLiteVectorIOConfig from llama_stack.providers.remote.inference.ollama import OllamaImplConfig from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -53,7 +53,7 @@ def get_distribution_template() -> DistributionTemplate: vector_io_provider_faiss = Provider( provider_id="faiss", provider_type="inline::faiss", - config=FaissImplConfig.sample_run_config(f"distributions/{name}"), + config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"), ) vector_io_provider_sqlite = Provider( provider_id="sqlite_vec", diff --git a/llama_stack/templates/remote-vllm/vllm.py b/llama_stack/templates/remote-vllm/vllm.py index 6c835ef86..40a2d541d 100644 --- a/llama_stack/templates/remote-vllm/vllm.py +++ b/llama_stack/templates/remote-vllm/vllm.py @@ -16,7 +16,7 @@ from llama_stack.distribution.datatypes import ( from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) -from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.remote.inference.vllm import VLLMInferenceAdapterConfig from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -55,7 +55,7 @@ def get_distribution_template() -> DistributionTemplate: vector_io_provider = Provider( provider_id="faiss", provider_type="inline::faiss", - config=FaissImplConfig.sample_run_config(f"distributions/{name}"), + config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"), ) inference_model = ModelInput( diff --git a/llama_stack/templates/tgi/tgi.py b/llama_stack/templates/tgi/tgi.py index e49c98d72..71718a93d 100644 --- a/llama_stack/templates/tgi/tgi.py +++ b/llama_stack/templates/tgi/tgi.py @@ -16,7 +16,7 @@ from llama_stack.distribution.datatypes import ( from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) -from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig +from llama_stack.providers.inline.vector_io.faiss.config import 
FaissVectorIOConfig from llama_stack.providers.remote.inference.tgi import TGIImplConfig from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -55,7 +55,7 @@ def get_distribution_template() -> DistributionTemplate: vector_io_provider = Provider( provider_id="faiss", provider_type="inline::faiss", - config=FaissImplConfig.sample_run_config(f"distributions/{name}"), + config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"), ) inference_model = ModelInput( diff --git a/llama_stack/templates/together/together.py b/llama_stack/templates/together/together.py index b7ac130ed..9ec5b38ba 100644 --- a/llama_stack/templates/together/together.py +++ b/llama_stack/templates/together/together.py @@ -18,7 +18,7 @@ from llama_stack.distribution.datatypes import ( from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) -from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.providers.remote.inference.together import TogetherImplConfig from llama_stack.providers.remote.inference.together.together import MODEL_ALIASES from llama_stack.templates.template import DistributionTemplate, RunConfigSettings @@ -51,7 +51,7 @@ def get_distribution_template() -> DistributionTemplate: vector_io_provider = Provider( provider_id="faiss", provider_type="inline::faiss", - config=FaissImplConfig.sample_run_config(f"distributions/{name}"), + config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"), ) embedding_provider = Provider( provider_id="sentence-transformers", diff --git a/llama_stack/templates/vllm-gpu/vllm.py b/llama_stack/templates/vllm-gpu/vllm.py index 54ebd2d41..31900687b 100644 --- a/llama_stack/templates/vllm-gpu/vllm.py +++ b/llama_stack/templates/vllm-gpu/vllm.py @@ -10,7 +10,7 @@ from llama_stack.providers.inline.inference.sentence_transformers import ( SentenceTransformersInferenceConfig, ) from llama_stack.providers.inline.inference.vllm import VLLMConfig -from llama_stack.providers.inline.vector_io.faiss.config import FaissImplConfig +from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig from llama_stack.templates.template import ( DistributionTemplate, RunConfigSettings, @@ -46,7 +46,7 @@ def get_distribution_template() -> DistributionTemplate: vector_io_provider = Provider( provider_id="faiss", provider_type="inline::faiss", - config=FaissImplConfig.sample_run_config(f"distributions/{name}"), + config=FaissVectorIOConfig.sample_run_config(f"distributions/{name}"), ) embedding_provider = Provider( provider_id="sentence-transformers", From efdd60014d9a0adc3999b37cdbb4d3f931d19a14 Mon Sep 17 00:00:00 2001 From: Yuan Tang Date: Thu, 13 Feb 2025 13:44:57 -0500 Subject: [PATCH 27/27] test: Enable logprobs top_k tests for remote::vllm (#1080) top_k support was added in https://github.com/meta-llama/llama-stack/pull/1074. The tests should be enabled as well.
Verified that tests pass for remote::vllm: ``` LLAMA_STACK_BASE_URL=http://localhost:5003 pytest -v tests/client-sdk/inference/test_text_inference.py -k " test_completion_log_probs_non_streaming or test_completion_log_probs_streaming" ================================================================ test session starts ================================================================ platform linux -- Python 3.10.16, pytest-8.3.4, pluggy-1.5.0 -- /home/yutang/.conda/envs/distribution-myenv/bin/python3.10 cachedir: .pytest_cache rootdir: /home/yutang/repos/llama-stack configfile: pyproject.toml plugins: anyio-4.8.0 collected 14 items / 12 deselected / 2 selected tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_non_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [ 50%] tests/client-sdk/inference/test_text_inference.py::test_completion_log_probs_streaming[meta-llama/Llama-3.1-8B-Instruct] PASSED [100%] =================================================== 2 passed, 12 deselected, 1 warning in 10.03s ==================================================== ``` Signed-off-by: Yuan Tang --- tests/client-sdk/inference/test_text_inference.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/tests/client-sdk/inference/test_text_inference.py b/tests/client-sdk/inference/test_text_inference.py index 206629602..c931ca255 100644 --- a/tests/client-sdk/inference/test_text_inference.py +++ b/tests/client-sdk/inference/test_text_inference.py @@ -14,13 +14,7 @@ PROVIDER_TOOL_PROMPT_FORMAT = { "remote::vllm": "json", } -PROVIDER_LOGPROBS_TOP_K = set( - { - "remote::together", - "remote::fireworks", - # "remote:vllm" - } -) +PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"} @pytest.fixture(scope="session")
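For context on how a provider-support set such as `PROVIDER_LOGPROBS_TOP_K` is typically consumed, here is a minimal sketch of the gating pattern: unsupported providers are skipped rather than failed. This is an illustrative example, not code from the patch; the `skip_if_no_logprobs_top_k` helper and the parametrized provider names are assumptions, and only the `PROVIDER_LOGPROBS_TOP_K` set itself comes from the diff above.

```
# Hypothetical sketch -- not part of the patch. It shows how a provider-support
# set can gate logprobs top_k tests: unsupported providers are skipped.
import pytest

# Mirrors the set in the diff above; "remote::vllm" is now included.
PROVIDER_LOGPROBS_TOP_K = {"remote::together", "remote::fireworks", "remote::vllm"}


def skip_if_no_logprobs_top_k(provider_type: str) -> None:
    """Skip the calling test when the provider cannot return top-k logprobs."""
    if provider_type not in PROVIDER_LOGPROBS_TOP_K:
        pytest.skip(f"logprobs top_k is not supported by {provider_type}")


@pytest.mark.parametrize("provider_type", ["remote::vllm", "remote::ollama"])
def test_logprobs_gating(provider_type: str) -> None:
    skip_if_no_logprobs_top_k(provider_type)
    # remote::vllm reaches this assertion; remote::ollama is skipped above.
    assert provider_type in PROVIDER_LOGPROBS_TOP_K
```

Keeping the support list as a plain set means enabling a newly capable provider is a one-line change, which is exactly what this patch does for remote::vllm.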